library(tidyverse)
## ── Attaching core tidyverse packages ────────────────────────────────────────────────────────────────── tidyverse 2.0.0 ──
## ✔ dplyr 1.1.4 ✔ readr 2.1.5
## ✔ forcats 1.0.0 ✔ stringr 1.5.1
## ✔ ggplot2 3.5.1 ✔ tibble 3.2.1
## ✔ lubridate 1.9.3 ✔ tidyr 1.3.1
## ✔ purrr 1.0.2
## ── Conflicts ──────────────────────────────────────────────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag() masks stats::lag()
## ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(glue)
source("util.R")
# Limit the values of `x` to the interval [min, max].
#
# x:   vector of values to limit.
# max: upper bound; elements greater than `max` are replaced by `max`.
# min: lower bound, defaults to `-max`; elements smaller than `min` are
#      replaced by `min`.
# Returns a vector like `x`. Elements for which neither comparison is TRUE
# (including NA) are passed through unchanged.
clamp <- function(x, max, min = -max) {
  too_high <- x > max
  too_low <- x < min
  case_when(
    too_high ~ max,
    too_low ~ min,
    .default = x
  )
}
# Predictions of all benchmark runs, stacked into a single table.
pert_res <- readRDS("../benchmark/output/double_perturbation_results_predictions.RDS") |>
  bind_rows()

# Run parameters: one row per run (`id`, `name`) and perturbation. The names of
# `test_train_labels` become the `train` column, its values are unnested into
# `perturbation`, and the `parameters` entries are unpacked into plain columns.
parameters <- readRDS("../benchmark/output/double_perturbation_results_parameters.RDS") %>%
  map(\(p) {
    tibble(
      id = p$id,
      name = p$name,
      parameters = as_tibble(p$parameters),
      train = names(p$test_train_labels),
      perturbation = p$test_train_labels
    )
  }) %>%
  bind_rows() %>%
  unnest(perturbation) %>%
  unpack(parameters)
# Combine the predictions with the run parameters.
#
# Perturbation labels are normalised so that each one reads "A+B", "A+ctrl" or
# "ctrl"; the run `name` is split into dataset, seed and method.
res <- pert_res %>%
  # Split on the first "+" or "_" into at most two parts.
  mutate(perturbation_split = str_split(perturbation, pattern = "[+_]", n = 2)) %>%
  mutate(perturbation_split = map(perturbation_split, \(x) {
    if (all(x == "ctrl" | x == "")) {
      "ctrl"
    } else if (length(x) == 2) {
      x
    } else {
      # Single perturbation: pad with "ctrl" as the second partner.
      c(x, "ctrl")
    }
  })) %>%
  mutate(perturbation = map_chr(perturbation_split, paste0, collapse = "+")) %>%
  tidylog::left_join(parameters, by = c("id", "name", "perturbation")) %>% # Matches most of x. Non matches are from scGPT and are not in training
  tidylog::filter(!is.na(train)) %>%
  separate(name, sep = "-", into = c("dataset_name2", "seed2", "method"), convert = TRUE) %>%
  # Consistency check: the dataset AND the seed encoded in the run name must
  # both agree with the joined parameters. (The previous `|` let a row pass as
  # soon as one of the two matched.)
  tidylog::filter(dataset_name2 == dataset_name & seed2 == seed) %>%
  dplyr::select(-c(dataset_name2, seed2)) %>%
  filter(method != "lpm")
## left_join: added 7 columns (dataset_name, test_train_config_id, seed, perturbation_type, model_type, …)
## > rows only in x 0
## > rows only in parameters ( 0)
## > matched rows 11,250
## > ========
## > rows total 11,250
## filter: no rows removed
## filter: no rows removed
# Print the combined prediction table for inspection.
res
## # A tibble: 11,250 × 13
## id method perturbation prediction prediction_std perturbation_split
## <chr> <chr> <chr> <named li> <named list> <list>
## 1 6248f7c56f1… scgpt AHR+FEV <dbl> <NULL> <chr [2]>
## 2 6248f7c56f1… scgpt AHR+KLF1 <dbl> <NULL> <chr [2]>
## 3 6248f7c56f1… scgpt AHR+ctrl <dbl> <NULL> <chr [2]>
## 4 6248f7c56f1… scgpt ARID1A+ctrl <dbl> <NULL> <chr [2]>
## 5 6248f7c56f1… scgpt ARRDC3+ctrl <dbl> <NULL> <chr [2]>
## 6 6248f7c56f1… scgpt ATL1+ctrl <dbl> <NULL> <chr [2]>
## 7 6248f7c56f1… scgpt BAK1+ctrl <dbl> <NULL> <chr [2]>
## 8 6248f7c56f1… scgpt BCL2L11+BAK1 <dbl> <NULL> <chr [2]>
## 9 6248f7c56f1… scgpt BCL2L11+TGF… <dbl> <NULL> <chr [2]>
## 10 6248f7c56f1… scgpt BCL2L11+ctrl <dbl> <NULL> <chr [2]>
## # ℹ 11,240 more rows
## # ℹ 7 more variables: dataset_name <chr>, test_train_config_id <chr>,
## # seed <int>, perturbation_type <chr>, model_type <chr>, epochs <dbl>,
## # train <chr>
# For seed 1, count the ground-truth conditions per dataset by the number of
# perturbed genes (entries of `perturbation_split` other than "ctrl").
res %>%
  filter(method == "ground_truth", seed == 1) %>%
  mutate(n_pert = map_int(perturbation_split, \(x) length(setdiff(x, "ctrl")))) %>%
  dplyr::count(dataset_name, n_pert)
## # A tibble: 3 × 3
## dataset_name n_pert n
## <chr> <int> <int>
## 1 norman_from_scfoundation 0 1
## 2 norman_from_scfoundation 1 100
## 3 norman_from_scfoundation 2 124
# Reshape a long data frame into a matrix.
#
# x:      data frame in long format.
# rows:   column (unquoted) whose values identify the matrix rows.
# cols:   column (unquoted) whose values become the matrix columns.
# values: column (unquoted) that fills the cells.
# ...:    forwarded to `pivot_wider()` (e.g. `values_fn`, `values_fill`).
# Returns a matrix; its row names are taken from the first column of the
# widened table.
long2matrix <- function(x, rows, cols, values, ...) {
  wide <- pivot_wider(
    transmute(x, {{ rows }}, {{ cols }}, {{ values }}),
    id_cols = {{ rows }},
    names_from = {{ cols }},
    values_from = {{ values }},
    ...
  )
  out <- as.matrix(wide[, -1])
  rownames(out) <- wide[[1]]
  out
}
# Heatmap (seed 1 only) of which method has a prediction for which perturbation.
res |>
filter(seed == 1) |>
# A prediction counts as present when its first element is not NA.
mutate(present = map_lgl(prediction, \(x) ! is.na(x[1]))) |>
(\(data){
# method x perturbation matrix of 0/1 flags; cells that are NA after the
# reshape are set to 0.
mat <- long2matrix(data, rows = method, cols = perturbation, values = present, values_fn = \(x) x * 1.0)
mat[is.na(mat)] <- 0
ComplexHeatmap::pheatmap(mat, main = "Valid perturbations", breaks = c(0,1), color = c("lightgrey", "darkred"),
show_row_dend = FALSE, show_column_dend = FALSE, show_colnames = FALSE, legend = FALSE)
})()
# Baseline expression: the ground-truth "ctrl" profile per dataset and seed.
baselines <- res %>%
filter(method == "ground_truth" & perturbation == "ctrl") %>%
dplyr::select(baseline = prediction, dataset_name, seed)
# Append a synthetic "no_change" method: for every distinct perturbation its
# prediction is the baseline profile of the same dataset and seed.
res <- bind_rows(res, res %>%
distinct(perturbation, perturbation_split, dataset_name, test_train_config_id, seed, train) %>%
inner_join(baselines %>% dplyr::rename(prediction = baseline), by = c("dataset_name", "seed")) %>%
mutate(method = "no_change"))
# Rank of every gene by its ground-truth control expression
# (rank 1 = highest value), computed per dataset and seed.
expr_rank_df <- res %>%
  filter(method == "ground_truth" & perturbation == "ctrl") %>%
  dplyr::select(dataset_name, seed, observed = prediction) %>%
  mutate(gene_name = map(observed, names)) %>%
  unnest(c(gene_name, observed)) %>%
  # `ties.method` is spelled out; the previous `ties =` only worked through
  # partial argument matching.
  mutate(expr_rank = rank(desc(observed), ties.method = "first"), .by = c(seed, dataset_name)) %>%
  dplyr::select(dataset_name, seed, gene_name, expr_rank)

# Rank of every gene by its absolute difference between the ground truth of a
# perturbation and the baseline (rank 1 = largest difference), computed per
# dataset, seed and perturbation.
de_rank_df <- res %>%
  filter(method == "ground_truth") %>%
  dplyr::select(dataset_name, seed, perturbation, observed = prediction) %>%
  mutate(gene_name = map(observed, names)) %>%
  unnest(c(gene_name, observed)) %>%
  left_join(
    baselines |>
      mutate(gene_name = map(baseline, names)) |>
      unnest(c(gene_name, baseline)),
    by = c("dataset_name", "seed", "gene_name")
  ) %>%
  mutate(de = abs(observed - baseline)) %>%
  mutate(de_rank = rank(desc(de), ties.method = "first"), .by = c(seed, dataset_name, perturbation)) %>%
  dplyr::select(dataset_name, seed, perturbation, gene_name, de_rank)
# Remove the upper limit on the vector heap size for this R session
# (session-wide side effect).
mem.maxVSize(vsize = Inf)
## [1] Inf
# Pair every non-ground-truth prediction with the ground truth of the same
# dataset, seed and perturbation; the ground truth is stored as `observed`.
contr_res <- tidylog::full_join(filter(res, method != "ground_truth"),
filter(res, method == "ground_truth") %>%
dplyr::select(dataset_name, seed, perturbation, observed = prediction),
by = c("dataset_name", "seed", "perturbation"))
## full_join: added one column (observed)
## > rows only in filter(res, method != ".. 0
## > rows only in filter(res, method == ".. 0
## > matched rows 11,250
## > ========
## > rows total 11,250
# Per (dataset, seed, method, perturbation, train) performance metrics,
# restricted to the genes with expr_rank <= 1000.
#   r2       : Pearson correlation (cor()) of prediction and observed
#   r2_delta : Pearson correlation of the baseline-subtracted values
#   l2       : Euclidean distance between prediction and observed
# NOTE(review): despite their names, `r2` and `r2_delta` are plain correlation
# coefficients, not squared ones. cor() returns NA (with a warning) when one of
# its inputs has zero standard deviation.
res_metrics <- contr_res %>%
tidylog::left_join(baselines, by = c("dataset_name", "seed")) %>%
dplyr::select(-c(id, test_train_config_id)) %>%
mutate(gene_name = map(prediction, names)) %>%
unnest(c(gene_name, prediction, observed, baseline)) %>%
inner_join(expr_rank_df %>% dplyr::select(dataset_name, seed, gene_name, expr_rank) %>% filter(expr_rank <= 1000), by = c("dataset_name", "seed", "gene_name")) %>%
summarize(r2 = cor(prediction, observed),
r2_delta = cor(prediction - baseline, observed - baseline),
l2 =sqrt(sum((prediction - observed)^2)),
.by = c(dataset_name, seed, method, perturbation, train))
## left_join: added one column (baseline)
## > rows only in x 0
## > rows only in baselines ( 0)
## > matched rows 11,250
## > ========
## > rows total 11,250
## Warning: There were 1170 warnings in `summarize()`.
## The first warning was:
## ℹ In argument: `r2_delta = cor(prediction - baseline, observed - baseline)`.
## ℹ In group 225: `dataset_name = "norman_from_scfoundation"`, `seed = 1`, `method = "scgpt"`, `perturbation = "ctrl"`,
## `train = "train"`.
## Caused by warning in `cor()`:
## ! the standard deviation is zero
## ℹ Run `dplyr::last_dplyr_warnings()` to see the 1169 remaining warnings.
# Print the metric table for inspection.
res_metrics
## # A tibble: 11,250 × 8
## dataset_name seed method perturbation train r2 r2_delta l2
## <chr> <int> <chr> <chr> <chr> <dbl> <dbl> <dbl>
## 1 norman_from_scfoundation 1 scgpt AHR+FEV test 0.968 0.398 7.40
## 2 norman_from_scfoundation 1 scgpt AHR+KLF1 val 0.989 0.118 4.76
## 3 norman_from_scfoundation 1 scgpt AHR+ctrl train 0.996 0.570 2.72
## 4 norman_from_scfoundation 1 scgpt ARID1A+ctrl train 0.990 0.694 4.82
## 5 norman_from_scfoundation 1 scgpt ARRDC3+ctrl train 0.999 0.644 1.57
## 6 norman_from_scfoundation 1 scgpt ATL1+ctrl train 0.993 0.751 3.88
## 7 norman_from_scfoundation 1 scgpt BAK1+ctrl train 0.997 -0.217 2.88
## 8 norman_from_scfoundation 1 scgpt BCL2L11+BAK1 train 0.996 -0.138 3.10
## 9 norman_from_scfoundation 1 scgpt BCL2L11+TGF… val 0.997 0.216 2.65
## 10 norman_from_scfoundation 1 scgpt BCL2L11+ctrl train 0.998 0.101 2.47
## # ℹ 11,240 more rows
# Display names; the names of each vector are the keys used in the data and
# their order defines the factor level order used in the plots.
method_labels <- c("no_change" = "No Change", "additive_model" = "Additive",
"scgpt" = "scGPT", "scfoundation" = "scFoundation",
"uce" = "UCE*", "scbert" = "scBERT*", "geneformer" = "Geneformer*",
"gears" = "GEARS", "cpa" = "CPA")
dataset_labels <- c("norman_from_scfoundation" = "Norman")
approach_labels <- c("baseline" = "Baselines", "foundation_model" = "Foundation Models", "deep_learning" = "Other Deep Learning Models")
# Assign every method in `method_labels` to an approach group.
# "uce33" is listed below but is not a key of `method_labels`, so it never
# occurs here. The fallback "Forgot" is not a level of `approach_labels` and
# therefore turns into NA in the factor conversion.
approach_annot <- tibble(method = names(method_labels)) |>
mutate(approach = case_when(
method %in% c("no_change", "additive_model") ~ "baseline",
method %in% c("scgpt", "scfoundation", "uce", "uce33", "scbert", "geneformer") ~ "foundation_model",
method %in% c("gears", "cpa") ~ "deep_learning",
.default = "Forgot"
)) |>
mutate(method = factor(method, levels = names(method_labels))) %>%
mutate(approach = factor(approach, levels = names(approach_labels)))
# Perturbation highlighted in the plots (seed 1 only).
sel_perts <- res_metrics %>%
filter(seed == 1) %>%
filter(perturbation %in% c("CEBPE+CEBPB"))
# Plot data: test/val perturbations of the methods listed in `method_labels`.
main_pl_data <- res_metrics %>%
filter(train %in% c("test", "val")) %>%
filter(method %in% names(method_labels)) %>%
mutate(method = factor(method, levels = names(method_labels))) %>%
mutate(dataset_name = factor(dataset_name, levels = names(dataset_labels))) %>%
left_join(approach_annot, by = "method") %>%
# Axis label "<method label>|<approach label>"; the "|" is the separator given
# to legendry::key_range_auto() in the plots.
mutate(label = paste0(method_labels[as.character(method)], "|", approach_labels[as.character(approach)])) |>
# Order labels by approach first, then by method within the approach.
mutate(label = fct_reorder(label, as.integer(approach) * 1000 + as.integer(method) )) %>%
# Flag the rows that belong to the selected perturbation(s).
left_join(sel_perts %>% distinct(seed, method, perturbation) %>% mutate(highlight = TRUE), by = c("seed", "method", "perturbation")) %>%
replace_na(list(highlight = FALSE))
# Beeswarm plot of `r2_delta` per method, one facet per approach.
main_pl_double_pearson <- main_pl_data %>%
ggplot(aes(x = label, y = r2_delta)) +
geom_hline(yintercept = c(0, 1), color = "black", linewidth = 0.2) +
ggbeeswarm::geom_quasirandom(aes(color = highlight, size = highlight)) +
# Dashed grey line at the highest per-method mean of r2_delta.
geom_hline(data = . %>% summarize(r2_delta_best = mean(r2_delta), .by = method) %>% slice_max(r2_delta_best, n = 1, with_ties = FALSE),
aes(yintercept = r2_delta_best), color = "grey", linetype = "dashed") +
# Red horizontal bar at the mean of each method. mean() is called without
# na.rm, so a method with any NA value gets an NA mean.
ungeviz::geom_hpline(data = . %>% summarize(r2_delta = mean(r2_delta), .by = c(approach, label)), aes(y = r2_delta),
color = "red", width = 0.6, linewidth = 0.6) +
ggforce::facet_row(vars(approach), scales = "free_x", space = "free", labeller = as_labeller(approach_labels)) +
scale_color_manual(values = c("TRUE" = "orange", "FALSE" = alpha("#444444", 0.6))) +
scale_size_manual(values = c("TRUE" = 0.6, "FALSE" = 0.1)) +
scale_x_discrete(expand = expansion(add = 0.9)) +
# Values outside [-0.25, 1] are dropped by the scale limits.
scale_y_continuous(limits = c(-0.25, 1), expand = expansion(add = 0)) +
labs(y = "Pearson delta") +
# Nested x axis: the part of `label` after "|" becomes the outer range label.
guides(x = legendry::guide_axis_nested(key = legendry::key_range_auto(sep = "\\|")),
color = "none", size = "none") +
theme(axis.title.x = element_blank(),
panel.grid.major.y = element_line(color = "lightgrey", linewidth = 0.1),
panel.grid.minor.y = element_line(color = "lightgrey", linewidth = 0.1),
strip.background = element_blank(), strip.text = element_blank(),
panel.spacing.x = unit(0, "pt"))
# Beeswarm plot of the L2 prediction error per method, one facet per approach.
main_pl_double_l2 <- main_pl_data %>%
ggplot(aes(x = label, y = l2)) +
geom_hline(yintercept = 0, color = "black", linewidth = 0.2) +
ggbeeswarm::geom_quasirandom(aes(color = highlight, size = highlight)) +
# Dashed grey line at the lowest per-method mean L2.
geom_hline(data = . %>% summarize(l2_best = mean(l2), .by = method) %>% slice_min(l2_best, n = 1, with_ties = FALSE),
aes(yintercept = l2_best), color = "grey", linetype = "dashed") +
# Red horizontal bar at the mean of each method.
ungeviz::geom_hpline(data = . %>% summarize(l2 = mean(l2), .by = c(approach, label)), aes(y = l2),
color = "red", width = 0.6, linewidth = 0.6) +
# Arrow and text annotation for "CEBPE+CEBPB", drawn in the "baseline" facet at
# hard-coded coordinates.
ggbezier::geom_bezier(data = tibble(approach = factor("baseline", levels(main_pl_data$approach)), x = c(1, 1.5), y = c(8.5, 10)),
aes(x = x, y = y, angle = c(0, 90)), arrow = grid::arrow(type = "closed", ends = "first", length = unit(1, "mm"))) +
# `font_size_tiny` is not defined in the visible part of this file
# (presumably provided by util.R; TODO confirm).
geom_text(data = tibble(x = 1.5, y = 10.1, text = "CEBPE+CEBPB", approach = factor("baseline", levels(main_pl_data$approach))),
aes(x=x, y=y, label=text), hjust = 0.2, vjust = 0, size = font_size_tiny / .pt) +
ggforce::facet_row(vars(approach), scales = "free_x", space = "free", labeller = as_labeller(approach_labels)) +
scale_x_discrete(expand = expansion(add = 0.7)) +
# Values outside [0, 12.5] are dropped by the scale limits.
scale_y_continuous(limits = c(0, 12.5), expand = expansion(add = c(0, 0.5))) +
scale_color_manual(values = c("TRUE" = "orange", "FALSE" = alpha("#444444", 0.6))) +
scale_size_manual(values = c("TRUE" = 0.6, "FALSE" = 0.1)) +
guides(x = legendry::guide_axis_nested(key = legendry::key_range_auto(sep = "\\|")),
color = "none", size = "none") +
labs(y = "Prediction error ($L_2$)") +
theme(axis.title.x = element_blank(),
panel.grid.major.y = element_line(color = "lightgrey", linewidth = 0.1),
panel.grid.minor.y = element_line(color = "lightgrey", linewidth = 0.1),
strip.background = element_blank(), strip.text = element_blank(),
panel.spacing.x = unit(0, "pt"))
# Render the correlation plot.
main_pl_double_pearson
## Warning: Removed 326 rows containing missing values or values outside the scale
## range (`position_quasirandom()`).
## Warning: Removed 1 row containing missing values or values outside the scale
## range (`geom_hpline()`).
# Render the L2 error plot.
main_pl_double_l2
## Warning: Removed 147 rows containing missing values or values outside the scale
## range (`position_quasirandom()`).
# Scatter plot of predicted vs. observed change over the baseline for the
# selected perturbation(s), one panel per method, top-1000 expressed genes.
obs_pred_corr_pl <- contr_res %>%
# Keep only the selected perturbation(s); this also brings in the metric
# columns of `sel_perts` (r2, r2_delta, l2).
inner_join(sel_perts, by = c("dataset_name", "seed", "method", "perturbation")) %>%
filter(method %in% names(method_labels)) %>%
mutate(method = factor(method, levels = names(method_labels))) %>%
left_join(approach_annot, by = "method") %>%
mutate(perturbation = fct_reorder(perturbation, -l2)) %>%
tidylog::left_join(baselines, by = c("dataset_name", "seed")) %>%
dplyr::select(-c(id, test_train_config_id)) %>%
mutate(gene_name = map(prediction, names)) %>%
unnest(c(gene_name, prediction, observed, baseline)) %>%
inner_join(expr_rank_df %>% dplyr::select(dataset_name, seed, gene_name, expr_rank) %>%
filter(expr_rank <= 1000), by = c("dataset_name", "seed", "gene_name")) %>%
# Differences are clamped to [-1, 1], so points beyond that sit on the border.
mutate(obs_minus_baseline = clamp(observed - baseline, min = -1, max = 1),
pred_minus_baseline = clamp(prediction - baseline, min = -1, max = 1)) %>%
ggplot(aes(x = obs_minus_baseline, y = pred_minus_baseline)) +
geom_abline(linewidth = 0.2, linetype = "dashed") +
ggrastr::rasterize(geom_point(size = 0.5, stroke = 0), dpi = 600) +
# Semi-transparent white box behind the metric labels.
annotate("rect", xmin = -0.95, ymin = 0.5, xmax = 0.3, ymax = Inf, fill = "white", alpha = 0.8) +
geom_text(data = . %>% summarize(l2 = first(l2), .by = c(method, perturbation, approach)), aes(label = paste0("$L_2$: ", round(l2, 1))),
x = -0.95, y = Inf, hjust = 0, vjust = 1.2, size = font_size_tiny / .pt) +
# NOTE(review): the value printed as "$R^2$" is `r2_delta`, which is computed
# with cor() (a Pearson correlation, not a squared quantity) - confirm the label.
geom_text(data = . %>% summarize(r2_delta = first(r2_delta), .by = c(method, perturbation, approach)),
aes(label = paste0("$R^2$: ", round(r2_delta, 2))),
x = -0.95, y = Inf, hjust = 0, vjust = 2.5, size = font_size_tiny / .pt) +
coord_fixed(xlim = c(-1, 1), ylim = c(-1, 1)) +
scale_x_continuous(breaks = c(-1, 0, 1)) +
scale_y_continuous(breaks = c(-1, 0, 1)) +
# ggh4x::facet_nested_wrap(vars(approach, method), nest_line = TRUE,
# Only `method` is a facet variable here; the `approach` entry of the labeller
# has no matching facet variable.
ggh4x::facet_wrap2(vars(method),
nrow = 3, # ncol = 2,
labeller = labeller(approach = as_labeller(approach_labels), method = as_labeller(method_labels)),
strip = ggh4x::strip_nested(clip = "off")) +
labs(x = "Observed LFC over control", y = "Predicted LFC over control")
## left_join: added one column (baseline)
## > rows only in x 0
## > rows only in baselines (4)
## > matched rows 9
## > ===
## > rows total 9
# Render the observed-vs-predicted scatter plot.
obs_pred_corr_pl
# Ranks at which the cumulative error is kept: every rank up to 100, every 10th
# up to 1000, every 100th up to 19264 (hard-coded upper end).
sel_ranks <- c(seq(1, 100, by = 1), seq(101, 1000, by = 10), seq(1001, 19264, by = 100))
# For correlation, I could use the TTR::runCor function, but it is slow
# Gene-level table (prediction, observed, baseline) for all non-training
# perturbations.
strat_data_init <- contr_res %>%
filter(train != "train") %>%
tidylog::left_join(baselines, by = c("dataset_name", "seed")) %>%
dplyr::select(-c(id, test_train_config_id, prediction_std, epochs)) %>%
mutate(gene_name = map(prediction, names)) %>%
unnest(c(gene_name, prediction, observed, baseline))
## left_join: added one column (baseline)
## > rows only in x 0
## > rows only in baselines ( 0)
## > matched rows 3,100
## > =======
## > rows total 3,100
# L2 error over the top-n genes when genes are ordered by expression rank.
# Rows are sorted by rank first, so the grouped cumsum() accumulates the
# squared error from rank 1 upwards; `dist` at rank n is the error over the
# top n genes.
strat_data_expr_rank <- strat_data_init %>%
inner_join(expr_rank_df %>% dplyr::select(dataset_name, seed, gene_name, rank = expr_rank),
by = c("dataset_name", "seed", "gene_name")) %>%
arrange(rank) %>%
mutate(dist = sqrt(cumsum((prediction - observed)^2)),
.by = c(dataset_name, seed, method, perturbation)) %>%
filter(rank %in% sel_ranks)
# Same computation with genes ordered by their per-perturbation
# differential-expression rank.
strat_data_de_rank <- strat_data_init %>%
left_join(de_rank_df %>% dplyr::select(dataset_name, seed, perturbation, gene_name, rank = de_rank),
by = c("dataset_name", "seed", "gene_name", "perturbation")) %>%
arrange(rank) %>%
mutate(dist = sqrt(cumsum((prediction - observed)^2)),
.by = c(dataset_name, seed, method, perturbation))%>%
filter(rank %in% sel_ranks)
# Average the cumulative L2 error per method, dataset, rank cutoff and gene
# ordering ("expr" or "de").
#   dist_mean : mean of `dist` within the group
#   dist_se   : standard error of that mean, sd / sqrt(number of observations)
strat_merged <- bind_rows(
  strat_data_expr_rank %>% mutate(sorted_by = "expr"),
  strat_data_de_rank %>% mutate(sorted_by = "de")
) %>%
  mutate(sorted_by = factor(sorted_by, levels = c("expr", "de"))) %>%
  summarize(
    dist_mean = mean(dist),
    # The standard error is scaled by the group size n(); the previous code
    # divided by sqrt(rank), which is the gene cutoff and not a sample size.
    dist_se = sd(dist) / sqrt(n()),
    .by = c(method, dataset_name, rank, sorted_by)
  )
# One colour per entry of `method_labels` (despite the variable name, the
# palette has length(method_labels) colours).
ggplot_colors_five <- colorspace::qualitative_hcl(length(method_labels), h = c(0, 270), c = 60, l = 70)
names(ggplot_colors_five) <- names(method_labels)
# Line plot of the mean cumulative L2 error against the number of top genes,
# faceted by gene ordering.
strat_pl <- strat_merged %>%
filter(method %in% names(method_labels)) %>%
mutate(method = factor(method, levels = names(method_labels))) %>%
# Hand-tuned vertical justification for the line-end labels of some methods.
mutate(custom_vjust = case_when(
method == "gears" ~ -0.3,
method == "scgpt" ~ 0.1,
method == "uce" ~ 1.2,
method == "scbert" ~ 0.1,
.default = 0.5
)) |>
ggplot(aes(x = rank, y = dist_mean)) +
ggrastr::rasterize(geom_line(aes(color = method), show.legend=FALSE), dpi = 600) +
# Method names at the largest rank, in a darkened version of the line colour.
geom_text(data = . %>% filter(method != "cpa") %>% filter(rank == max(rank)),
aes(label = method_labels[method], color = stage(method, after_scale = colorspace::darken(color, 0.5)),
vjust = custom_vjust),
hjust = 0, size = font_size_small / .pt, show.legend = FALSE) +
# The "cpa" label is drawn at y = 14, at the point per facet where its curve
# is closest above 14.
geom_text(data = . %>% filter(method == "cpa") %>% slice_min(ifelse(dist_mean > 14, dist_mean - 14, Inf), by = sorted_by),
aes(label = method_labels[method], color = stage(method, after_scale = colorspace::darken(color, 0.5))),
y = 14, vjust = 1, hjust = -0.2, size = font_size_small / .pt, show.legend = FALSE) +
# Dashed vertical line at rank 1000, only in the "expr" facet.
geom_vline(data = tibble(rank = 1000, sorted_by = factor("expr", levels = c("expr", "de"))), aes(xintercept = rank),
linewidth = 0.4, linetype = "dashed", color = "grey") +
scale_x_log10(labels = scales::label_comma(), limits = c(1, NA), expand = expansion(mult = c(0, 0.1))) +
# scale_y_continuous(limits = c(0, NA), expand = expansion(mult = c(0, 0)), breaks = c(0, 1, 2, 5, 10, 20), transform = scales::asinh_trans()) +
scale_y_continuous(expand = expansion(add = 0)) +
scale_color_manual(values = ggplot_colors_five) +
facet_wrap(vars(sorted_by), scales = "free_y",
labeller = as_labeller(c("expr" = "genes sorted by expression", "de" = "genes sorted by differential expression"))) +
labs(x = "top $n$ genes (log-scale)", y = "Prediction error ($L_2$)") +
# Clipping is off so the labels can extend beyond the panel; the y range is
# zoomed to [0, 14] without dropping data.
coord_cartesian(clip = "off", ylim = c(0, 14)) +
theme(panel.spacing.x = unit(14, "mm"))
# Render the stratified error plot.
strat_pl
# Assemble the supplementary figure (panels A and B) and write it to
# "../plots/suppl-pearson_delta_performance.pdf".
# `plot_assemble`, `add_text`, `add_plot` and `font_size` are not defined in the
# visible part of this file (presumably provided by util.R; TODO confirm).
plot_assemble(
add_text("(A) Double perturbation prediction correlation",
x = 2.7, y = 1, fontsize = font_size, vjust = 1, fontface = "bold"),
add_plot(main_pl_double_pearson, x = 3, y = 4, width = 130, height = 47.5),
add_text("(B) Prediction error stratified by the considered gene sets",
x = 2.7, y = 53, fontsize = font_size, vjust = 1, fontface = "bold"),
add_plot(strat_pl, x = 3, y = 55, width = 100, height = 80),
width = 170, height = 135, units = "mm", show_grid_lines = FALSE,
latex_support = TRUE, filename = "../plots/suppl-pearson_delta_performance.pdf"
)
## Using TikZ metrics dictionary at:
## double_perturbation_analysis-tikzDictionary
## gg[gg1]
## Warning: Removed 326 rows containing missing values or values outside the scale
## range (`position_quasirandom()`).
## Warning: Removed 1 row containing missing values or values outside the scale
## range (`geom_hpline()`).
## gg[gg2]
## gg[gg3]
## gg[gg4]
## [1] TRUE TRUE
# For every perturbation without "ctrl" in its label, list the four conditions
# needed for the additivity check: the pair itself ("AB"), each partner
# combined with "ctrl" ("A", "B"), and "ctrl". One row per condition.
all_combs <- tibble(perturbation = res$perturbation |> discard(\(x) str_detect(x, "ctrl")) |> unique()) %>%
mutate(split = str_split(perturbation, "\\+")) %>%
mutate(combs = map(split, \(x) list(x, c(x[1], "ctrl"), c(x[2], "ctrl"), "ctrl")),
labels = map(split, \(x) c("AB", "A", "B", "ctrl"))) %>%
transmute(pert_group = perturbation,
# Sort each combination so it can be matched independently of gene order.
combs = map(combs, \(x) map(x, sort, method = "radix")),
labels) %>%
unnest(c(combs, labels))
# Ground-truth values per gene for AB, A, B and ctrl of every pair, plus
# `error`: the difference between AB and the additive expectation A + B - ctrl.
ground_truth_df <- res %>%
filter(method == "ground_truth") %>%
# Sorted the same way as `combs` above; the list column is the join key.
mutate(perturbation_split = map(perturbation_split, sort, method = "radix")) %>%
dplyr::select(perturbation, perturbation_split, seed, train, ground_truth = prediction) %>%
inner_join(all_combs, by = c("perturbation_split" = "combs"), relationship = "many-to-many") %>%
# `unnest_named_lists` is not defined in the visible part of this file
# (presumably provided by util.R; TODO confirm).
unnest_named_lists(ground_truth, names_to = "gene_name") %>%
pivot_wider(id_cols = c(gene_name, pert_group, seed), names_from = labels, values_from = ground_truth) %>%
mutate(error = `AB` - (A + B - ctrl))
# Restrict to seed 1 and the genes with expr_rank <= 1000.
filter_gt_df <- ground_truth_df %>%
filter(seed == 1) |>
inner_join(expr_rank_df %>% dplyr::select(gene_name, seed, rank = expr_rank) %>% filter(rank <= 1000) , by = c("gene_name", "seed"))
set.seed(1)
# Local false discovery rate fit on the deviations from additivity.
locfdr_est <- locfdr::locfdr(filter_gt_df$error, nulltype = 1)
## Warning: glm.fit: fitted rates numerically 0 occurred
## Warning in locfdr::locfdr(filter_gt_df$error, nulltype = 1): f(z) misfit =
## 76.7. Rerun with increased df
# Inspect the `z.2` element and the null-parameter table `fp0` of the fit.
locfdr_est$z.2
## [1] -0.1597242 0.1591445
locfdr_est$fp0
## delta sigma p0
## thest 0.0000000000 1.0000000000 8.826975628
## theSD 0.0000000000 0.0000000000 0.015472160
## mlest 0.0069048519 0.0588595687 0.841291742
## mleSD 0.0003337505 0.0005341399 0.004987568
## cmest -0.0005052317 0.0839470793 0.988610934
## cmeSD 0.0002836887 0.0001713179 0.000869025
# Null-distribution parameters taken from the "mlest" row of `fp0`.
mean_est <- locfdr_est$fp0["mlest","delta"]
sd_est <- locfdr_est$fp0["mlest","sigma"]
p0_est <- locfdr_est$fp0["mlest", "p0"]
# Thresholds: among the positive (resp. negative) deviations, the one whose
# local fdr is closest to 0.05.
upper_thres <- tibble(deviation = filter_gt_df$error, fdr = locfdr_est$fdr) %>%
filter(deviation > 0) %>%
slice_min(abs(fdr - 0.05), with_ties = FALSE) %>%
pull(deviation)
lower_thres <- tibble(deviation = filter_gt_df$error, fdr = locfdr_est$fdr) %>%
filter(deviation < 0) %>%
slice_min(abs(fdr - 0.05), with_ties = FALSE) %>%
pull(deviation)
upper_thres
##
## 0.2017037
# Print the lower threshold.
lower_thres
##
## -0.2138857
# Draw tick marks along a straight line through `origin` with direction `dir`.
#
# origin: length-2 vector, a point on the line.
# dir:    length-2 direction vector of the line.
# at:     positions along the line (multiples of `dir`) where ticks are drawn.
# length: scale of the tick; each tick spans `length/2 * orth_dir` to either
#         side of the line, where `orth_dir` is `dir` rotated by 90 degrees
#         (not normalised, so the drawn length scales with the size of `dir`).
# ...:    forwarded to `geom_segment()`.
# Returns a `geom_segment()` layer with its own data.
annotate_ticks <- function(origin = c(0, 0), dir = c(1, 0), at = seq(-10, 10), length = 0.1, ...) {
  orth_dir <- c(dir[2], -dir[1])
  # 2 x length(at) matrix whose column i is the point origin + at[i] * dir.
  # Built with base vapply() instead of the unexported lemur:::mply_dbl().
  pos <- vapply(at, \(t) origin + t * dir, numeric(2))
  # `orth_dir` is recycled down the columns: row 1 is x, row 2 is y.
  start <- pos + length / 2 * orth_dir
  end <- pos - length / 2 * orth_dir
  dat <- tibble(pos = t(pos), start = t(start), end = t(end))
  geom_segment(data = dat, aes(x = start[, 1], xend = end[, 1], y = start[, 2], yend = end[, 2]), ...)
}
# Place text labels along a straight line through `origin` with direction
# `dir`, rotated to the angle of the line.
#
# origin:   length-2 vector, a point on the line.
# dir:      length-2 direction vector of the line.
# labels:   label text, one per element of `at` (defaults to `at` itself).
# at:       positions along the line (multiples of `dir`) of the labels.
# offset:   shift of the labels perpendicular to the line, in multiples of
#           `orth_dir` (`dir` rotated by 90 degrees, not normalised).
# extra_df: optional data frame whose columns are bound to the label data.
# ...:      forwarded to `geom_text()`.
# Returns a `geom_text()` layer with its own data.
annotate_labels_along <- function(origin = c(0, 0), dir = c(1, 0), labels = at, at = 0, offset = 0, extra_df = NULL, ...) {
  orth_dir <- c(dir[2], -dir[1])
  # 2 x length(at) matrix whose column i is the point origin + at[i] * dir.
  # Built with base vapply() instead of the unexported lemur:::mply_dbl().
  pos <- vapply(at, \(t) origin + t * dir, numeric(2))
  dat <- bind_cols(tibble(pos = t(pos), labels), extra_df)
  # Text angle in degrees, derived from the direction of the line.
  angle <- atan2(dir[2], dir[1]) / pi * 180
  geom_text(
    data = dat,
    aes(label = labels, x = pos[, 1] + offset * orth_dir[1], y = pos[, 2] + offset * orth_dir[2]),
    angle = angle, ...
  )
}
# Probabilities marked along the reference line of the Q-Q plot.
label_pos <- c(0.001, 0.01, 0.1, 0.2, 0.5, 0.8, 0.9, 0.99, 0.999)
# Normal Q-Q plot of the deviation from additivity.
qq_pl <- filter_gt_df %>%
mutate(percent_rank = percent_rank(error)) %>%
arrange(error) %>%
# Theoretical standard-normal quantiles for the sorted errors.
mutate(expect_quantile = qnorm(ppoints(n()))) %>%
ggplot(aes(x = expect_quantile, y = error)) +
# Reference line through the origin with slope `sd_est`; the ticks and labels
# below use origin (0, 0) as well, so `mean_est` is not used as an intercept.
geom_abline(slope = sd_est) +
annotate_ticks(dir = c(1, sd_est), at = qnorm(label_pos), length = 0.17) +
# The first four labels are offset to one side of the line, the rest to the other.
annotate_labels_along(dir = c(1, sd_est), at = qnorm(label_pos[1:4]), labels = label_pos[1:4], offset = -0.25, size = font_size_small / .pt) +
annotate_labels_along(dir = c(1, sd_est), at = qnorm(label_pos[5:9]), labels = label_pos[5:9], offset = 0.25, size = font_size_small / .pt) +
annotate_labels_along(dir = c(1, sd_est), at = 4.5, labels = "Percentile", offset = 0.15, size = font_size_small / .pt) +
ggrastr::rasterize(geom_point(size = 0.3, stroke = 0), dpi = 300) +
coord_fixed() +
labs(x = "Quantiles of a standard normal distribution", y = "Quantiles of the observed\nLFC over additive expectation")
# Render the Q-Q plot.
qq_pl
# Parse interval labels such as "[-0.45, -0.44)" or "(1, 2.5]" into their
# numeric bounds.
#
# label: character vector of interval labels (opening "(" or "[", two numbers
#        separated by a comma and optional whitespace, closing ")" or "]").
# Returns a numeric matrix with one row per label and two columns (lower and
# upper bound). Labels that are NA or do not match the format give a row of NA.
#
# Besides plain decimals, the bounds may be written in scientific notation
# ("1e-04"), with a leading dot (".5") or as "Inf"/"-Inf"; the previous pattern
# turned such labels into NA.
bin_numeric <- function(label) {
  label <- as.character(label)
  number <- "[+-]?(?:Inf|(?:\\d+\\.?\\d*|\\.\\d+)(?:[eE][+-]?\\d+)?)"
  pattern <- paste0("^[\\(\\[](", number, "),\\s*(", number, ")[\\]\\)]$")
  matched <- !is.na(label) & grepl(pattern, label, perl = TRUE)

  bounds <- matrix(NA_real_, nrow = length(label), ncol = 2)
  bounds[matched, 1] <- as.numeric(sub(pattern, "\\1", label[matched], perl = TRUE))
  bounds[matched, 2] <- as.numeric(sub(pattern, "\\2", label[matched], perl = TRUE))
  bounds
}
# Return the first rows of `data` that satisfy `condition`.
#
# data:      data frame.
# condition: filter expression evaluated on `data`.
# order_by:  expression used to sort the filtered rows before slicing
#            (defaults to `row_number()`).
# ...:       forwarded to `slice_head()` (e.g. `n`).
slice_first <- function(data, condition, order_by = row_number(), ...) {
  data |>
    filter({{ condition }}) |>
    arrange({{ order_by }}) |>
    slice_head(...)
}
# Per observation: the kernel density estimate of the errors (`obs_dens`), the
# density expected under the fitted null (`expected_dens`), and their ratio
# capped at 1.
dens_ratio_df <- filter_gt_df %>%
# `filter_gt_df` is already restricted to seed 1 where it is built.
filter(seed == 1) %>%
mutate(obs_dens = error |> (\(err){
dens <- density(err, bw = "nrd0")
# Interpolate the density estimate at the observed error values.
approx(dens$x, dens$y, err)$y
})(),
expected_dens = p0_est * dnorm(error, mean = mean_est, sd = sd_est)) %>%
mutate(ratio = pmin(1, expected_dens / obs_dens))
# upper_thres <- dens_ratio_df %>% slice_first(ratio < 0.1 & error > 0, order_by = error) %>% pull(error)
# lower_thres <- dens_ratio_df %>% slice_first(ratio < 0.1 & error < 0, order_by = desc(error)) %>% pull(error)
# Number of observations above, below and between the two thresholds, formatted
# with thousands separators, plus the x position where each count is printed in
# the histogram.
count_labels <- filter_gt_df %>%
mutate(label = case_when(
error > upper_thres ~ "synergy",
error < lower_thres ~ "suppressive",
.default = "additive"
)) %>%
count(label) %>%
mutate(n = scales::label_comma()(n)) %>%
# No `by` is given, so the join uses the common column `label`.
left_join(enframe(c(additive = 0, suppressive = -0.4, synergy = 0.4), name = "label", value = "pos"))
## Joining with `by = join_by(label)`
# Histogram of the deviation from additivity. Each bin is split into a null part
# (h0, n * mean(ratio)) and a non-null part (h1), overlaid with the fitted null
# density.
error_histogram <- dens_ratio_df %>%
mutate(error_bin = santoku::chop_width(error, width = 0.01)) %>%
# Numeric lower/upper bound of every bin, parsed from the bin label.
mutate(bin_num = bin_numeric(as.character(error_bin))) %>%
summarize(count_h0 = n() * mean(ratio),
count_h1 = n() * (1-mean(ratio)),
.by = c(error_bin, bin_num)) %>%
pivot_longer(starts_with("count_"), names_sep = "_", names_to = c(".value", "origin")) %>%
mutate(origin = factor(origin, levels = c("h1", "h0"))) %>%
mutate(bin_width = matrixStats::rowDiffs(bin_num)) %>%
# x is the bin midpoint; y is the count rescaled to a density.
ggplot(aes(x = rowMeans(bin_num), y = count / sum(count) / bin_width)) +
geom_col(aes(fill = origin), width = 0.01, position = "stack", show.legend = FALSE) +
# Fitted null density scaled by the estimated null proportion.
geom_function(fun = \(x) p0_est * dnorm(x, mean = mean_est, sd = sd_est), n = 1e4, color = "red") +
geom_vline(xintercept = c(lower_thres, upper_thres), color = "#040404", linewidth = 0.2) +
geom_text(data = count_labels, aes(x = pos, y = Inf, label = n), hjust = 0.5, vjust = 1.2, size = font_size_small / .pt) +
scale_fill_manual(values = c("h0" = "lightgrey", "h1" = "black")) +
# Bins outside [-0.45, 0.45] are dropped by the scale limits.
scale_x_continuous(limits = c(-0.45, 0.45)) +
scale_y_continuous(expand = expansion(mult = c(0, 0.1))) +
labs(y = "density", x = "Observed LFC over additive expectation")
# Render the histogram.
error_histogram
## Warning: `position_stack()` requires non-overlapping x intervals.
## Warning: Removed 158 rows containing missing values or values outside the scale
## range (`geom_col()`).
# Assemble the Q-Q plot (A) and the null decomposition histogram (B) and write
# them to "../plots/suppl-qqplot.pdf".
# `plot_assemble`, `add_text`, `add_plot` and `font_size` are not defined in the
# visible part of this file (presumably provided by util.R; TODO confirm).
plot_assemble(
add_text("(A) Quantile-Quantile plot of the difference from the additive expectation",
x = 2.7, y = 1, fontsize = font_size, vjust = 1, fontface = "bold"),
add_plot(qq_pl, x = 3, y = 2, width = 120, height = 47.5),
add_text("(B) Empirical null decomposition", x = 124.7, y = 1, fontsize = font_size, vjust = 1, fontface = "bold"),
add_plot(error_histogram, x = 125, y = 4.5, width = 50, height = 41.5),
width = 180, height = 50, units = "mm", show_grid_lines = FALSE,
latex_support = TRUE, filename = "../plots/suppl-qqplot.pdf"
)
## gg[gg1]
## gg[gg2]
## gg[gg3]
## Warning: `position_stack()` requires non-overlapping x intervals.
## Warning: Removed 158 rows containing missing values or values outside the scale
## range (`geom_col()`).
## Measuring dimensions of: 118,965
## Running command: '/Library/TeX/texbin/pdflatex' -interaction=batchmode -halt-on-error -output-directory '/var/folders/f4/z30qmj3j55db5zh17mq85tr80000gq/T//RtmpQNSStN/tikzDevice172a5507cb41e' 'tikzStringWidthCalc.tex'
## Measuring dimensions of: 1,408
## Running command: '/Library/TeX/texbin/pdflatex' -interaction=batchmode -halt-on-error -output-directory '/var/folders/f4/z30qmj3j55db5zh17mq85tr80000gq/T//RtmpQNSStN/tikzDevice172a52b1546da' 'tikzStringWidthCalc.tex'
## Measuring dimensions of: 3,627
## Running command: '/Library/TeX/texbin/pdflatex' -interaction=batchmode -halt-on-error -output-directory '/var/folders/f4/z30qmj3j55db5zh17mq85tr80000gq/T//RtmpQNSStN/tikzDevice172a51dea84f' 'tikzStringWidthCalc.tex'
## gg[gg4]
## [1] TRUE
# Cross-tabulate how many deviations lie above the upper / below the lower threshold.
filter_gt_df %>%
count(error > upper_thres, error < lower_thres)
## # A tibble: 3 × 3
## `error > upper_thres` `error < lower_thres` n
## <lgl> <lgl> <int>
## 1 FALSE FALSE 118965
## 2 FALSE TRUE 1408
## 3 TRUE FALSE 3627
# Fill colours for the interaction categories.
non_additive_colors <- c("Additive" = "lightgrey", "Other" = "#767676", "Non-additive" = "#00BA38", "Synergy" = "#fdc086",
"Buffering" = "#beaed4", "Cryptic" = colorspace::darken("#f0027f", 0.5))
# Classify every gene x perturbation pair (top-1000 expressed genes) by how the
# double perturbation AB deviates from the additive expectation A + B - ctrl.
gene_response_label_df_intermed <- ground_truth_df |>
inner_join(expr_rank_df %>% dplyr::select(gene_name, seed, expr_rank) %>% filter(expr_rank <= 1000), by = c("gene_name", "seed")) %>%
mutate(add = (A + B - ctrl)) |>
# TRUE when A, B and AB all lie on the same side of ctrl.
mutate(pert_same_dir = sign(A - ctrl) == sign(B - ctrl) & sign(AB - ctrl) == sign(A - ctrl)) %>%
# Conditions are evaluated top to bottom; the first match wins.
# NOTE(review): the two "Cryptic" conditions require AB on the opposite side of
# ctrl from the additive expectation, but `pert_same_dir` already requires AB on
# the same side as A and B. The two requirements exclude each other, so
# "Cryptic" can never be assigned and such rows end up as "Other" or
# "Additive". Confirm whether `pert_same_dir` should only compare A and B.
mutate(label = case_when(
pert_same_dir & ctrl < (A+B-ctrl) & AB - (A + B - ctrl) > upper_thres ~ "Synergy",
pert_same_dir & ctrl > (A+B-ctrl) & AB - (A + B - ctrl) < lower_thres ~ "Synergy",
pert_same_dir & ctrl < (A+B-ctrl) & AB - (A + B - ctrl) < lower_thres & AB < ctrl ~ "Cryptic",
pert_same_dir & ctrl > (A+B-ctrl) & AB - (A + B - ctrl) < lower_thres & AB > ctrl ~ "Cryptic",
pert_same_dir & ctrl < (A+B-ctrl) & AB - (A + B - ctrl) < lower_thres ~ "Buffering",
pert_same_dir & ctrl > (A+B-ctrl) & AB - (A + B - ctrl) > upper_thres ~ "Buffering",
AB - (A + B - ctrl) > upper_thres | AB - (A + B - ctrl) < lower_thres ~ "Other",
.default = "Additive"
))
# Fix the category order used for stacking in the plots.
gene_response_label_df <- gene_response_label_df_intermed |>
mutate(label = factor(label, levels = c("Additive", "Other", "Buffering", "Synergy", "Cryptic")))
# Count and fraction of every interaction label for seed 1, plus a coarse
# additive / non-additive grouping.
non_additive_counts <- gene_response_label_df %>%
filter(seed == 1) |>
dplyr::count(label) %>%
mutate(frac = n / sum(n)) |>
mutate(is_additive = ifelse(label == "Additive", "Additive", "Non-additive"))
print(non_additive_counts)
## # A tibble: 4 × 4
## label n frac is_additive
## <fct> <int> <dbl> <chr>
## 1 Additive 118965 0.959 Additive
## 2 Other 1396 0.0113 Non-additive
## 3 Buffering 2878 0.0232 Non-additive
## 4 Synergy 761 0.00614 Non-additive
# Fraction of additive observations; used as the end of the bracket below.
perc_additive <- filter(non_additive_counts, label == "Additive")$frac
# Stacked bar (drawn as rectangles) of the additive vs. non-additive fractions.
non_add_pl1 <- non_additive_counts %>%
summarize(frac = sum(frac), .by = is_additive) |>
# Cumulative start/end position of every segment on the y axis.
mutate(start = cumsum(lag(frac, default = 0)),
end = cumsum(frac)) |>
ggplot(aes(ymin = start, ymax = end, xmin = 0.4, xmax = 0.6)) +
geom_rect(aes(fill = is_additive), show.legend = FALSE) +
# "\\%" escapes the percent sign in the axis labels.
scale_y_continuous(limits = c(0, 1), labels = \(x) paste0(x * 100, "\\%"), expand = expansion(add = 0), position = "left") +
scale_fill_manual(values = non_additive_colors) +
# Secondary y axis with a range annotation "Additive" from 0 to perc_additive.
guides(y.sec = legendry::compose_stack(legendry::primitive_bracket(key = legendry::key_range_manual(start = 0, end = perc_additive, name = "Additive"), angle = -90))) +
theme(legendry.bracket = element_blank(),
legendry.bracket.size = unit(0, "pt"),
axis.text.x = element_blank(),
axis.ticks.x = element_blank(),
axis.title.x = element_blank(),
axis.line.x = element_blank())
# Cumulative start/end of every label in factor order; the "Additive" row is
# dropped afterwards so only the non-additive segments remain.
non_additive_counts_helper_df <- non_additive_counts %>%
arrange(label) %>%
mutate(start = cumsum(lag(frac, default = 0)),
end = cumsum(frac)) |>
mutate(label = fct_rev(label)) |>
filter(label != "Additive")
# Range key for the secondary axis, one range per non-additive label; all ranges
# are put on level 1.
non_additive_annot_key <- legendry::key_range_manual(start = non_additive_counts_helper_df$start, end = non_additive_counts_helper_df$end, name = non_additive_counts_helper_df$label)
non_additive_annot_key$.level <- 1
# Zoomed stacked bar of the non-additive categories; the lower y limit is left
# to the data (NA), the upper limit is 1.
non_add_pl2 <- non_additive_counts_helper_df |>
ggplot(aes(ymin = start, ymax = end, xmin = 0.4, xmax = 0.6)) +
geom_rect(aes(fill = label), show.legend = FALSE) +
scale_y_continuous(limits = c(NA, 1), labels = \(x) paste0(x * 100, "\\%"), expand = expansion(add = 0), position = "left") +
scale_fill_manual(values = non_additive_colors) +
guides(y.sec = legendry::compose_stack(legendry::primitive_bracket(key = non_additive_annot_key, angle = -90))) +
theme(zoom = element_rect(fill = "grey"),
legendry.bracket = element_blank(),
legendry.bracket.size = unit(0, "pt"),
axis.text.x = element_blank(),
axis.ticks.x = element_blank(),
axis.title.x = element_blank(),
axis.line.x = element_blank())
# Render both bars.
non_add_pl1
non_add_pl2
# Long table with one row per (perturbation, gene, seed, method): each method's
# predicted value and the ground truth, both expressed relative to the additive
# model's prediction (`ref`).
inter_pred_dat <- res %>%
# Keep rows whose `train` label is "test" or "val".
filter(train %in% c("test", "val")) %>%
# Keep perturbations with exactly two non-"ctrl" components.
filter(lengths(map(perturbation_split, \(x) setdiff(x, "ctrl"))) == 2) %>%
filter(method %in% c("ground_truth", names(method_labels))) %>%
tidylog::left_join(baselines, by = c("dataset_name", "seed")) %>%
dplyr::select(perturbation, method, seed, prediction, baseline) %>%
# Gene names are taken from the names of each `prediction` vector, then the
# three list columns are unnested in parallel (one row per gene).
mutate(gene_name = map(prediction, names)) %>%
unnest(c(gene_name, prediction, baseline)) %>%
# Restrict to genes with expr_rank <= 1000 for the same seed.
inner_join(expr_rank_df %>% dplyr::select(gene_name, seed, expr_rank) %>% filter(expr_rank <= 1000), by = c("gene_name", "seed")) %>%
# One column per method, so that every row carries ground_truth and additive_model.
pivot_wider(id_cols = c(perturbation, gene_name, baseline, seed), names_from = method, values_from = prediction) %>%
# The additive model's prediction is the reference for all differences below.
mutate(ref = additive_model) %>%
# Back to long format for the listed methods; ground_truth and ref stay as columns.
pivot_longer(c(scgpt, gears, scfoundation, scbert, geneformer, uce, cpa, additive_model, no_change), names_to = "method") %>%
mutate(obs_minus_add = ground_truth - ref,
pred_minus_add = value - ref) %>%
mutate(method = factor(method, levels = names(method_labels)))
## left_join: added one column (baseline)
## > rows only in x 0
## > rows only in baselines ( 0)
## > matched rows 3,100
## > =======
## > rows total 3,100
# Per-gene predictions annotated with the approach of each method and with the
# interaction label of the (perturbation, gene, seed) combination.
pert_pred_comparison_df <- inter_pred_dat %>%
filter(method %in% names(method_labels)) %>%
mutate(method = factor(method, levels = names(method_labels))) %>%
left_join(approach_annot, by = "method") %>%
# Inner join: rows without a matching label in gene_response_label_df are dropped.
# `perturbation` is matched against `pert_group`; `label` is renamed to `interaction_label`.
tidylog::inner_join(gene_response_label_df|> dplyr::select(gene_name, pert_group, seed, interaction_label = label), by = c("perturbation" = "pert_group", "gene_name", "seed"))
## inner_join: added one column (interaction_label)
## > rows only in x ( 0)
## > rows only in dplyr::select(gene_resp.. ( 310,000)
## > matched rows 2,790,000
## > ===========
## > rows total 2,790,000
# For rows labelled "Synergy": count, per method, how many fall into each
# corner of the observed-vs-predicted plane, i.e. where both differences from
# the additive expectation lie outside [lower_thres, upper_thres].
pert_pred_comparison_df |>
  filter(interaction_label == "Synergy") |>
  count(
    bottom_left = obs_minus_add < lower_thres & pred_minus_add < lower_thres,
    top_left = obs_minus_add < lower_thres & pred_minus_add > upper_thres,
    top_right = obs_minus_add > upper_thres & pred_minus_add > upper_thres,
    bottom_right = obs_minus_add > upper_thres & pred_minus_add < lower_thres,
    method
  ) |>
  # One row per (method, corner) flag; keep only the flags that are TRUE.
  pivot_longer(-c(method, n), names_to = "corner", values_to = "exist") |>
  filter(exist) |>
  dplyr::select(-exist) |>
  print(n = 50)
## # A tibble: 26 × 3
## method n corner
## <fct> <int> <chr>
## 1 no_change 321 bottom_right
## 2 scgpt 255 bottom_right
## 3 scfoundation 202 bottom_right
## 4 uce 250 bottom_right
## 5 scbert 250 bottom_right
## 6 geneformer 69 bottom_right
## 7 gears 147 bottom_right
## 8 cpa 75 bottom_right
## 9 scfoundation 15 top_right
## 10 geneformer 78 top_right
## 11 gears 36 top_right
## 12 cpa 133 top_right
## 13 no_change 714 top_left
## 14 scgpt 405 top_left
## 15 scfoundation 90 top_left
## 16 uce 251 top_left
## 17 scbert 251 top_left
## 18 geneformer 161 top_left
## 19 gears 158 top_left
## 20 cpa 821 top_left
## 21 scfoundation 185 bottom_left
## 22 uce 1 bottom_left
## 23 scbert 1 bottom_left
## 24 geneformer 204 bottom_left
## 25 gears 153 bottom_left
## 26 cpa 116 bottom_left
# Scatter of observed vs. predicted difference from the additive expectation,
# one panel per method (nested under its approach).
pert_pred_comparison <- pert_pred_comparison_df |>
  # Flag, per method, the 500 rows with the largest absolute predicted difference.
  mutate(most_non_additive = rank(desc(abs(pred_minus_add))) <= 500, .by = method) |>
  # Clip predictions to [-1.37, 1.37] so they stay inside the plotted y range.
  mutate(pred_minus_add = clamp(pred_minus_add, max = 1.37)) |>
  # Row order follows the interaction label; points are added in row order.
  arrange(interaction_label) |>
  ggplot(aes(x = obs_minus_add, y = pred_minus_add)) +
  # geom_point(data = tibble(interaction_label = "Below baseline"), aes(x= 0, y = 0, color = interaction_label), stroke = 0, size = 0) +
  ggrastr::rasterize(
    geom_point(
      aes(color = interaction_label, alpha = most_non_additive, size = most_non_additive),
      stroke = 0
    ),
    dpi = 600
  ) +
  geom_abline(slope = 1, intercept = 0, linetype = "dashed", alpha = 0.3) +
  geom_vline(xintercept = c(lower_thres, upper_thres), linewidth = 0.2) +
  scale_x_continuous(expand = expansion(add = 0), breaks = c(-0.8, 0, 0.8)) +
  scale_y_continuous(expand = expansion(add = 0), breaks = c(-1, 0, 1)) +
  scale_color_manual(
    values = non_additive_colors,
    labels = c("Additive", "Other", "Buffering", "Synergy", "Cryptic"),
    drop = TRUE
  ) +
  # Highlighted rows are opaque and larger; all others are faint and small.
  scale_alpha_manual(values = c("TRUE" = 1, "FALSE" = 0.1)) +
  scale_size_manual(values = c("TRUE" = 0.6, "FALSE" = 0.1)) +
  ggh4x::facet_nested_wrap(
    vars(approach, method),
    nrow = 1,
    labeller = labeller(
      approach = as_labeller(approach_labels),
      method = as_labeller(method_labels)
    ),
    nest_line = TRUE,
    strip = ggh4x::strip_nested(clip = "off")
  ) +
  coord_fixed(xlim = c(-1, 1), ylim = c(-1.4, 1.4)) +
  labs(
    x = "Observed LFC over additive expectation",
    y = "Predicted LFC over\nadditive expectation",
    color = ""
  ) +
  guides(
    color = guide_legend(override.aes = list(size = 2)),
    alpha = "none",
    size = "none"
  ) +
  theme(panel.spacing.x = unit(2, "mm"), legend.position = "bottom")

pert_pred_comparison
# Data-frame flavoured wrapper around stats::approx().
#
# x, y  Vectors of coordinates to interpolate. They are embraced when the input
#       tibble is built, so its columns are named after the caller's
#       expressions (e.g. `approx2(fdp, tpr, ...)` yields `fdp` and `tpr`).
# ...   Forwarded unchanged to approx() (e.g. xout, yleft, yright).
#
# Returns a two-column tibble with the interpolated points, carrying the same
# column names as the inputs.
approx2 <- function(x, y, ...) {
  input <- tibble({{ x }}, {{ y }})
  interpolated <- as_tibble(approx(input[[1]], input[[2]], ...))
  colnames(interpolated) <- colnames(input)
  interpolated
}
# Ranking statistics per (method, seed): rows are ranked by the absolute
# predicted difference from the additive model, and cumulative true/false
# positive counts are turned into fdp, fpr, precision, recall and tpr.
tp_fdp_prec_recall_data_pre <- inter_pred_dat %>%
# Left join: rows without a matching label keep NA in the added columns.
tidylog::left_join(gene_response_label_df|> dplyr::select(gene_name, pert_group, seed,pert_same_dir,interaction_label = label), by = c("perturbation" = "pert_group", "gene_name", "seed")) |>
filter(method %in% names(method_labels)) %>%
# The additive model is the reference itself (pred_minus_add is value - ref), so it is excluded.
filter(method != "additive_model") |>
mutate(method = factor(method, levels = names(method_labels))) %>%
left_join(approach_annot, by = "method") %>%
# A row is a true non-additive case if the observed difference lies outside [lower_thres, upper_thres].
mutate(true_nonadditive = obs_minus_add > upper_thres | obs_minus_add < lower_thres) %>%
group_by(method, seed) %>%
# arrange() is called without .by_group, so it orders the whole table; the
# grouped cumsum() below then runs in that order within each group.
arrange(desc(abs(pred_minus_add))) %>%
mutate(tp = cumsum(true_nonadditive),
fp = cumsum(! true_nonadditive)) %>%
# pmax(1, .) guards the fdp denominator against zero.
mutate(fdp = fp / pmax(1, fp + tp),
fpr = fp / sum(! true_nonadditive),
precision = tp / (tp + fp),
recall = tp / sum(true_nonadditive)) %>%
# Re-sort by fdp and replace tp by its running maximum, which makes tp (and
# hence tpr) non-decreasing in fdp. `recall` above keeps the pre-cummax values.
arrange(fdp) %>%
mutate(tp = cummax(tp)) %>%
mutate(tpr = tp / sum(true_nonadditive)) |>
ungroup()
## left_join: added 2 columns (pert_same_dir, interaction_label)
## > rows only in x 0
## > rows only in dplyr::select(gene_resp.. ( 310,000)
## > matched rows 2,790,000
## > ===========
## > rows total 2,790,000
# Interpolate the TPR-vs-FDP curve of every (method, seed) onto a common grid
# of 101 FDP values in [0, 1], then average the interpolated TPR over seeds.
tp_fdp_data <- tp_fdp_prec_recall_data_pre %>%
  group_by(method, seed) %>%
  reframe(tmp = approx2(
    fdp, tpr,
    xout = seq(0, 1, length.out = 101),
    # Outside the observed FDP range the curve is pinned to 0 (left) and 1 (right).
    yleft = 0, yright = 1,
    # Many rows share the same fdp. `mean` is already approx()'s default tie
    # handling; passing it explicitly (forwarded through approx2's `...`)
    # yields the same values without one "collapsing to unique 'x' values"
    # warning per group.
    ties = mean
  )) %>%
  ungroup() %>%
  unnest(tmp) %>%
  summarize(tpr = mean(tpr), .by = c(method, fdp)) %>%
  mutate(method = factor(method, levels = names(method_labels)))
## Warning: There were 40 warnings in `reframe()`.
## The first warning was:
## ℹ In argument: `tmp = approx2(...)`.
## ℹ In group 1: `method = no_change` and `seed = 1`.
## Caused by warning in `regularize.values()`:
## ! collapsing to unique 'x' values
## ℹ Run `dplyr::last_dplyr_warnings()` to see the 39 remaining warnings.
# Colour and combined "method|approach" label per method. The label factor is
# ordered by approach first and by method within an approach.
# NOTE(review): `ggplot_colors_five[method]` indexes with `method` as stored in
# approach_annot; if that column is a factor, the lookup is by integer code,
# not by name. Confirm this is intended.
colors_adapted <- approach_annot |>
  mutate(
    color = ggplot_colors_five[method],
    label = paste0(
      method_labels[as.character(method)],
      "|",
      approach_labels[as.character(approach)]
    ),
    label = fct_reorder(label, as.integer(approach) * 1000 + as.integer(method))
  )

# FDP value next to which each method's text label is anchored in the TPR/FDP plot.
label_pos <- c(
  no_change = 0.4,
  scgpt = 0.25,
  uce = 0.25,
  scfoundation = 0.5,
  gears = 0.4,
  geneformer = 0.3,
  cpa = 0.7,
  scbert = 0.32
)
# True positive rate (log scale) against false discovery proportion, one line
# per method, each labelled directly in the panel instead of via a legend.
tp_fdp_pl <- tp_fdp_data %>%
  left_join(approach_annot, by = "method") %>%
  mutate(approach = fct_relevel(approach, "baseline", "deep_learning", "foundation_model")) %>%
  mutate(label = paste0(method_labels[as.character(method)], "|", approach_labels[as.character(approach)])) %>%
  mutate(label = fct_reorder(label, as.integer(approach) * 1000 + as.integer(method))) %>%
  ggplot(aes(x = fdp, y = tpr)) +
  geom_line(aes(color = label)) +
  ggrepel::geom_text_repel(
    # Per method, only the grid point closest to label_pos gets the method
    # name; every other point gets an empty label.
    data = . %>% mutate(annot = ifelse(rank((fdp - label_pos[as.character(method)])^2, ties.method = "first") == 1, method_labels[as.character(method)], ""), .by = method),
    aes(label = annot), size = font_size_tiny / .pt,
    min.segment.length = 0, box.padding = 0.5, point.padding = 0,
    color = "black", bg.colour = "white", max.overlaps = Inf, seed = 1,
    # No trailing comma after the last argument: it would pass an empty
    # positional argument to geom_text_repel().
    ylim = c(-Inf, Inf)
    # label.size = NA, label.padding = unit(0.1, "mm")
  ) +
  annotation_logticks(sides = "l", short = unit(0.3, "mm"), mid = unit(2/3, "mm"), long = unit(1, "mm")) +
  scale_color_manual(values = deframe(colors_adapted[c("label", "color")])) +
  scale_y_log10(breaks = c(0.01, 0.1, 1)) +
  scale_x_continuous(expand = expansion(add = 0)) +
  labs(x = "False discovery proportion ($\\frac{\\textrm{FP}}{\\textrm{FP}+\\textrm{TP}}$)",
       y = "True Positive Rate ($\\frac{\\textrm{TP}}{\\textrm{TP}+\\textrm{FN}}$)",
       color = "") +
  coord_cartesian(ylim = c(0.001, 1), xlim = c(0, 1)) +
  theme(legend.position = "none", legend.key.height = unit(0.1, "mm"))

tp_fdp_pl
# Cross-tabulate, per method, the class implied by the prediction against the
# observed interaction label, using each method's 500 rows with the largest
# absolute predicted difference from the additive model.
prediction_label_vs_true_df <- tp_fdp_prec_recall_data_pre |>
  slice_max(order_by = abs(pred_minus_add), n = 500, by = method) |>
  mutate(
    prediction_label = case_when(
      # "buf": the prediction lies between baseline and additive reference.
      baseline < ref & ref > value & value >= baseline ~ "buf",
      baseline > ref & ref < value & value <= baseline ~ "buf",
      # "syn": the prediction goes beyond the reference, away from baseline.
      baseline < ref & ref < value ~ "syn",
      baseline > ref & ref > value ~ "syn",
      # "anti": the prediction is on the opposite side of baseline from the reference.
      baseline < ref & baseline > value ~ "anti",
      baseline > ref & baseline < value ~ "anti",
      .default = "other"
    )
  ) |>
  count(method, interaction_label, prediction_label) |>
  # Add the combinations that never occur, with a count of zero.
  complete(method, prediction_label, interaction_label, fill = list(n = 0)) |>
  summarize(n = median(n), .by = c(method, prediction_label, interaction_label)) |>
  # Fraction of each observed label within a (method, predicted class) cell.
  mutate(frac = n / sum(n), .by = c(prediction_label, method))

# Spot check: observed labels among geneformer's "buf" predictions.
prediction_label_vs_true_df |>
  filter(method == "geneformer", prediction_label == "buf")
## # A tibble: 5 × 5
## method prediction_label interaction_label n frac
## <fct> <chr> <fct> <int> <dbl>
## 1 geneformer buf Additive 54 0.286
## 2 geneformer buf Other 16 0.0847
## 3 geneformer buf Buffering 111 0.587
## 4 geneformer buf Synergy 8 0.0423
## 5 geneformer buf Cryptic 0 0
# Mosaic-style bars: per method and predicted class, the composition of the
# observed interaction labels, annotated with the number of predictions.
mosaic_plot <- prediction_label_vs_true_df |>
  # Number of predictions per (method, predicted class) ...
  mutate(marginal_n = sum(n), .by = c(method, prediction_label)) |>
  # ... and that number's share across predicted classes; this drives the bar width.
  mutate(marginal_frac = marginal_n / sum(marginal_n), .by = c(method, interaction_label)) |>
  mutate(
    interaction_label = fct_relevel(interaction_label, "Other", "Additive", "Buffering", "Synergy")
  ) |>
  ggplot(aes(x = frac, y = prediction_label)) +
  geom_col(
    aes(
      fill = interaction_label,
      # Empty classes get zero width; all other bars are at least 0.1 wide.
      width = ifelse(marginal_frac == 0, 0, pmax(0.1, 1.5 * marginal_frac))
    )
  ) +
  shadowtext::geom_shadowtext(
    data = . %>% distinct(method, prediction_label, marginal_n),
    aes(
      label = paste0("n=", scales::label_comma()(round(marginal_n))),
      # x = ifelse(marginal_n == 0, 0, Inf), hjust = ifelse(marginal_n == 0, 0, 0)),
      x = 0.02,
      hjust = 0
    ),
    color = "black",
    bg.colour = "white",
    vjust = 0.5,
    size = font_size_tiny / .pt,
    nudge_x = 0.01
  ) +
  scale_x_continuous(
    labels = function(x) paste0(x * 100, "\\%"),
    breaks = c(0, 0.5, 1),
    expand = expansion(add = 0),
    position = "bottom"
  ) +
  scale_y_discrete(labels = c(buf = "Buffering", syn = "Synergy", anti = "Opposite")) +
  scale_fill_manual(values = non_additive_colors) +
  facet_wrap(
    vars(method),
    labeller = labeller(method = as_labeller(method_labels)),
    nrow = 1,
    scales = "fixed"
  ) +
  labs(y = "Prediction class", x = "Proportion of observed interaction classes") +
  guides(fill = "none") +
  coord_cartesian(ylim = c(0.4, 3.2), xlim = c(0, 1), clip = "off", expand = FALSE) +
  theme(
    panel.spacing.x = unit(5, "mm"),
    panel.spacing.y = unit(1.5, "mm"),
    strip.text = element_text(margin = margin(2, 2, b = 0, 2, "pt"))
  )
## Warning in geom_col(aes(fill = interaction_label, width = ifelse(marginal_frac
## == : Ignoring unknown aesthetics: width
mosaic_plot
## Warning: Removed 40 rows containing missing values or values outside the scale
## range (`geom_col()`).
Make the receiver operating characteristic (ROC) and precision-recall (PRC) curves
# Legend labels "<method> ($mean \pm se$)" for the precision-recall plot,
# named by method and ordered by decreasing mean area.
# The area per (method, seed) is a trapezoid sum: rollmean(k = 2) gives the
# mean of neighbouring precision values, diff() the recall step between them.
# Rows are sorted by increasing precision and the sum is negated.
# NOTE(review): the negation presumably compensates for recall decreasing
# along that ordering - confirm.
auprc_dat_labels <- tp_fdp_prec_recall_data_pre %>%
  arrange(precision) %>%
  summarize(auprc = -sum(zoo::rollmean(precision, k = 2) * diff(recall)),
            .by = c(method, seed)) %>%
  summarize(mean = mean(auprc),
            se = sd(auprc) / sqrt(n()),
            .by = method) %>%
  arrange(-mean) %>%
  # Look labels up by name: `method` is a factor, and indexing a named vector
  # with a factor would use its integer codes instead.
  transmute(method, label = paste0(method_labels[as.character(method)], " ($", round(mean, digits = 2), "\\pm", round(se, digits = 2), "$)")) %>%
  deframe()

# Same construction for the ROC plot: trapezoid area under recall vs. fpr,
# with rows sorted by increasing fpr.
auc_dat_labels <- tp_fdp_prec_recall_data_pre %>%
  arrange(fpr) %>%
  summarize(auc = sum(zoo::rollmean(recall, k = 2) * diff(fpr)),
            .by = c(method, seed)) %>%
  summarize(mean = mean(auc),
            se = sd(auc) / sqrt(n()),
            .by = method) %>%
  arrange(-mean) %>%
  transmute(method, label = paste0(method_labels[as.character(method)], " ($", round(mean, digits = 2), "\\pm", round(se, digits = 2), "$)")) %>%
  deframe()
# Precision-recall curves, one panel per seed. Methods are ordered as in
# auprc_dat_labels, whose values also supply the legend labels.
prc_plot <- tp_fdp_prec_recall_data_pre |>
  mutate(method = factor(method, levels = names(auprc_dat_labels))) |>
  ggplot(aes(x = recall, y = precision)) +
  ggrastr::rasterize(geom_line(aes(color = method), linewidth = 0.2), dpi = 300) +
  scale_color_manual(values = ggplot_colors_five, labels = auprc_dat_labels) +
  scale_x_continuous(
    breaks = c(0, 0.25, 0.5, 0.75, 1),
    labels = as.character(c(0, 0.25, 0.5, 0.75, 1))
  ) +
  facet_wrap(vars(seed), labeller = label_both, nrow = 1) +
  coord_fixed() +
  guides(color = guide_legend(override.aes = list(linewidth = 0.8))) +
  labs(
    y = "Precision ($\\frac{\\textrm{TP}}{\\textrm{TP} + \\textrm{FP}}$)",
    x = "Recall ($\\textrm{TPR} = \\frac{\\textrm{TP}}{\\textrm{TP} + \\textrm{FN}}$)",
    color = "",
    title = "(A) Precision-Recall Curve (PRC)"
  )

# ROC curves, one panel per seed, with the dashed diagonal as visual reference.
# Methods are ordered and labelled via auc_dat_labels.
roc_plot <- tp_fdp_prec_recall_data_pre |>
  mutate(method = factor(method, levels = names(auc_dat_labels))) |>
  ggplot(aes(x = fpr, y = recall)) +
  ggrastr::rasterize(geom_line(aes(color = method), linewidth = 0.2), dpi = 300) +
  geom_abline(slope = 1, color = "lightgrey", linewidth = 0.8, linetype = "dashed") +
  scale_color_manual(values = ggplot_colors_five, labels = auc_dat_labels) +
  scale_x_continuous(
    breaks = c(0, 0.25, 0.5, 0.75, 1),
    labels = as.character(c(0, 0.25, 0.5, 0.75, 1))
  ) +
  facet_wrap(vars(seed), labeller = label_both, nrow = 1) +
  coord_fixed() +
  guides(color = guide_legend(override.aes = list(linewidth = 0.8))) +
  labs(
    x = "False Positive Rate ($\\frac{\\textrm{FP}}{\\textrm{FP} + \\textrm{TN}}$)",
    y = "Recall ($\\textrm{TPR} = \\frac{\\textrm{TP}}{\\textrm{TP} + \\textrm{FN}}$)",
    color = "",
    title = "(B) Receiver Operator Curve (ROC)"
  )

prc_plot
roc_plot
# Supplementary figure: PRC panel stacked above the ROC panel on a
# 180 x 105 mm canvas, written to ../plots/suppl-roc_curves.pdf.
plot_assemble(
  add_plot(prc_plot, x = 0, y = 0, width = 180, height = 50),
  add_plot(roc_plot, x = 0, y = 52, width = 180, height = 50),
  width = 180,
  height = 105,
  units = "mm",
  show_grid_lines = FALSE,
  latex_support = TRUE,
  filename = "../plots/suppl-roc_curves.pdf"
)
## Measuring dimensions of: No Change ($0.33\pm0.03$)
## Running command: '/Library/TeX/texbin/pdflatex' -interaction=batchmode -halt-on-error -output-directory '/var/folders/f4/z30qmj3j55db5zh17mq85tr80000gq/T//RtmpQNSStN/tikzDevice172a55df0b714' 'tikzStringWidthCalc.tex'
## Measuring dimensions of: scGPT ($0.29\pm0.03$)
## Running command: '/Library/TeX/texbin/pdflatex' -interaction=batchmode -halt-on-error -output-directory '/var/folders/f4/z30qmj3j55db5zh17mq85tr80000gq/T//RtmpQNSStN/tikzDevice172a56683b23a' 'tikzStringWidthCalc.tex'
## Measuring dimensions of: UCE* ($0.29\pm0.03$)
## Running command: '/Library/TeX/texbin/pdflatex' -interaction=batchmode -halt-on-error -output-directory '/var/folders/f4/z30qmj3j55db5zh17mq85tr80000gq/T//RtmpQNSStN/tikzDevice172a5502a326a' 'tikzStringWidthCalc.tex'
## Measuring dimensions of: scBERT* ($0.29\pm0.03$)
## Running command: '/Library/TeX/texbin/pdflatex' -interaction=batchmode -halt-on-error -output-directory '/var/folders/f4/z30qmj3j55db5zh17mq85tr80000gq/T//RtmpQNSStN/tikzDevice172a5253f644' 'tikzStringWidthCalc.tex'
## Measuring dimensions of: CPA ($0.11\pm0.02$)
## Running command: '/Library/TeX/texbin/pdflatex' -interaction=batchmode -halt-on-error -output-directory '/var/folders/f4/z30qmj3j55db5zh17mq85tr80000gq/T//RtmpQNSStN/tikzDevice172a5564ceb8d' 'tikzStringWidthCalc.tex'
## gg[gg1]
## Measuring dimensions of: scBERT* ($0.81\pm0.01$)
## Running command: '/Library/TeX/texbin/pdflatex' -interaction=batchmode -halt-on-error -output-directory '/var/folders/f4/z30qmj3j55db5zh17mq85tr80000gq/T//RtmpQNSStN/tikzDevice172a553fca23e' 'tikzStringWidthCalc.tex'
## Measuring dimensions of: UCE* ($0.81\pm0.01$)
## Running command: '/Library/TeX/texbin/pdflatex' -interaction=batchmode -halt-on-error -output-directory '/var/folders/f4/z30qmj3j55db5zh17mq85tr80000gq/T//RtmpQNSStN/tikzDevice172a56effbf85' 'tikzStringWidthCalc.tex'
## Measuring dimensions of: Geneformer* ($0.76\pm0.02$)
## Running command: '/Library/TeX/texbin/pdflatex' -interaction=batchmode -halt-on-error -output-directory '/var/folders/f4/z30qmj3j55db5zh17mq85tr80000gq/T//RtmpQNSStN/tikzDevice172a55876edb1' 'tikzStringWidthCalc.tex'
## Measuring dimensions of: scFoundation ($0.72\pm0.01$)
## Running command: '/Library/TeX/texbin/pdflatex' -interaction=batchmode -halt-on-error -output-directory '/var/folders/f4/z30qmj3j55db5zh17mq85tr80000gq/T//RtmpQNSStN/tikzDevice172a567ef2cd6' 'tikzStringWidthCalc.tex'
## gg[gg2]
## [1] TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE
# Colours of the cartoon bars: the three "FrenchDispatch" colours in reverse
# order, named by bar component.
cartoon_color_values <- wesanderson::wes_palette("FrenchDispatch", n = 3)[c(3, 2, 1)] |>
  as.character() |>
  magrittr::set_names(c("ctrl", "Ad", "Bd"))

# Heights of the components: control level plus the increments attributed to
# perturbation A ("Ad") and perturbation B ("Bd").
height_spec <- c(ctrl = 0.4, Ad = 0.1, Bd = 0.15)

# Bars of the training cartoon: control alone, control + Ad, control + Bd.
training_df <- tibble(
  bar = c("ctrl", "A", "A", "B", "B"),
  part = c("ctrl", "ctrl", "Ad", "ctrl", "Bd")
) |>
  mutate(height = height_spec[part])

# Stacked bar chart of the training cartoon; y axis text, ticks and both axis
# titles are hidden.
train_plot <- training_df |>
  mutate(
    bar = factor(bar, c("ctrl", "A", "B")),
    part = factor(part, c("Bd", "Ad", "ctrl"))
  ) |>
  arrange(bar, desc(part)) |>
  mutate(
    start_pos = cumsum(lag(height, default = 0)),
    end_pos = cumsum(height),
    .by = bar
  ) |>
  ggplot(aes(x = bar, y = height)) +
  geom_col(aes(fill = part), width = 0.5) +
  scale_y_continuous(expand = expansion(add = 0)) +
  scale_x_discrete(
    labels = c(ctrl = "No Perturbation", A = "Perturbation A", B = "Perturbation B")
  ) +
  scale_fill_manual(values = cartoon_color_values) +
  scale_color_manual(values = cartoon_color_values) +
  coord_cartesian(ylim = c(0, 0.8), clip = "off") +
  guides(
    color = "none",
    fill = "none",
    alpha = "none",
    x = guide_axis(angle = 90)
  ) +
  labs(
    y = "Log Expression of a single gene",
    subtitle = "Training data"
  ) +
  theme(
    axis.title = element_blank(),
    axis.ticks.y.left = element_blank(),
    axis.text.y.left = element_blank()
  )
# Bar height per interaction class, relative to the additive expectation
# (sum of all components): synergy above it, buffering below it, and
# "opp" below the control level.
categories_df <- tibble(
  add = sum(height_spec),
  syn = add + 0.1,
  buf = add - 0.12,
  opp = height_spec["ctrl"] - 0.1
) |>
  pivot_longer(everything(), names_to = "class", values_to = "height") |>
  mutate(class = factor(class, levels = c("add", "buf", "syn", "opp")))

# Stacked arrow segments (one per component), repeated for every class.
arrow_df <- enframe(height_spec, name = "part", value = "height") |>
  mutate(
    start_pos = cumsum(lag(height, default = 0)),
    end_pos = cumsum(height)
  ) |>
  mutate(class = list(categories_df$class)) |>
  unnest(class)

# Additive expectation; the zero point of the secondary axis.
sum_height_spec1 <- sum(height_spec)

# Grey bar per class with the component arrows overlaid; an error bar of
# +/- 0.05 around the additive expectation marks the "add" class.
obs_plot <- categories_df |>
  ggplot(aes(x = class, y = height)) +
  geom_col(fill = "#5e5e5e", width = 0.5) +
  annotate(
    geom = "errorbar",
    x = "add",
    ymin = sum_height_spec1 - 0.05,
    ymax = sum_height_spec1 + 0.05,
    width = 0.3
  ) +
  geom_segment(
    data = arrow_df,
    aes(color = part, y = start_pos, yend = end_pos),
    linewidth = 0.7,
    arrow = grid::arrow(type = "closed", length = unit(1.5, "mm"))
  ) +
  scale_y_continuous(
    expand = expansion(add = 0),
    sec.axis = sec_axis(
      transform = function(x) x - sum_height_spec1,
      name = "LFC over additive expectation",
      breaks = c(0)
    )
  ) +
  scale_x_discrete(
    labels = c(
      add = "Additive\n(Non-interaction)",
      buf = "Buffering",
      syn = "Synergy",
      opp = "Opposite"
    )
  ) +
  scale_fill_manual(values = cartoon_color_values) +
  scale_color_manual(values = cartoon_color_values) +
  coord_cartesian(ylim = c(0, 0.8), clip = "off") +
  guides(
    color = "none",
    fill = "none",
    alpha = "none",
    x = guide_axis(angle = 90)
  ) +
  labs(subtitle = "Observed / Predicted Interaction\nClass of Perturbation A+B") +
  theme(
    axis.title = element_blank(),
    axis.ticks.y.left = element_blank(),
    axis.text.y.left = element_blank(),
    axis.line.y.left = element_blank()
  )

# Preview of the first cartoon row.
cowplot::plot_grid(train_plot, obs_plot, align = "h")
# Second cartoon: same layout, but both perturbations now lower the level
# (negative increments). This overwrites height_spec and training_df.
height_spec <- c(ctrl = 0.4, Ad = -0.1, Bd = -0.15)

training_df <- tibble(
  bar = c("ctrl", "A", "A", "B", "B"),
  part = c("ctrl", "ctrl", "Ad", "ctrl", "Bd")
) |>
  mutate(
    height = height_spec[part],
    bar = factor(bar, c("ctrl", "A", "B")),
    part = factor(part, c("Bd", "Ad", "ctrl"))
  ) |>
  arrange(bar, desc(part)) |>
  mutate(
    start_pos = cumsum(lag(height, default = 0)),
    end_pos = cumsum(height),
    .by = bar
  )

# Tiles instead of stacked columns: each component spans from its start
# position to the lowest end position of its bar, the perturbation parts are
# drawn translucent, and arrows show the decrease from the control level.
train_plot2 <- training_df |>
  mutate(end_pos = min(end_pos), .by = bar) |>
  ggplot(aes(x = bar)) +
  geom_tile(
    aes(
      fill = part,
      alpha = part,
      y = (start_pos + end_pos) / 2,
      height = abs(end_pos - start_pos)
    ),
    width = 0.5
  ) +
  annotate(
    "segment",
    x = "A",
    y = height_spec["ctrl"],
    yend = height_spec["ctrl"] + height_spec["Ad"],
    color = cartoon_color_values["Ad"],
    linewidth = 0.7,
    arrow = grid::arrow(type = "closed", length = unit(1.5, "mm"))
  ) +
  annotate(
    "segment",
    x = "B",
    y = height_spec["ctrl"],
    yend = height_spec["ctrl"] + height_spec["Bd"],
    color = cartoon_color_values["Bd"],
    linewidth = 0.7,
    arrow = grid::arrow(type = "closed", length = unit(1.5, "mm"))
  ) +
  scale_y_continuous(expand = expansion(add = 0)) +
  scale_x_discrete(
    labels = c(ctrl = "No Perturbation", A = "Perturbation A", B = "Perturbation B")
  ) +
  scale_fill_manual(values = cartoon_color_values) +
  scale_color_manual(values = cartoon_color_values) +
  scale_alpha_manual(values = c(ctrl = 1, Ad = 0.2, Bd = 0.2)) +
  coord_cartesian(ylim = c(0, 0.5), clip = "off") +
  guides(
    color = "none",
    fill = "none",
    alpha = "none",
    x = guide_axis(angle = 90)
  ) +
  labs(y = "Log Expression of a single gene") +
  theme(
    axis.title = element_blank(),
    axis.ticks.y.left = element_blank(),
    axis.text.y.left = element_blank()
  )
# Class heights for the second cartoon. The offsets are mirrored relative to
# the first one: synergy lies below the additive expectation, buffering above
# it, and "opp" above the control level.
categories_df <- tibble(
  add = sum(height_spec),
  syn = add - 0.1,
  buf = add + 0.12,
  opp = height_spec["ctrl"] + 0.03
) |>
  pivot_longer(everything(), names_to = "class", values_to = "height") |>
  mutate(class = factor(class, levels = c("add", "buf", "syn", "opp")))

# Stacked arrow segments (one per component), repeated for every class.
arrow_df <- enframe(height_spec, name = "part", value = "height") |>
  mutate(
    start_pos = cumsum(lag(height, default = 0)),
    end_pos = cumsum(height)
  ) |>
  mutate(class = list(categories_df$class)) |>
  unnest(class)

# Additive expectation; the zero point of the secondary axis.
sum_height_spec2 <- sum(height_spec)

# Like obs_plot, except that the "ctrl" arrow is not drawn and there is no subtitle.
obs_plot2 <- categories_df |>
  ggplot(aes(x = class, y = height)) +
  geom_col(fill = "#5e5e5e", width = 0.5) +
  annotate(
    geom = "errorbar",
    x = "add",
    ymin = sum_height_spec2 - 0.05,
    ymax = sum_height_spec2 + 0.05,
    width = 0.3
  ) +
  geom_segment(
    data = arrow_df |> filter(part != "ctrl"),
    aes(color = part, y = start_pos, yend = end_pos),
    linewidth = 0.7,
    arrow = grid::arrow(type = "closed", length = unit(1.5, "mm"))
  ) +
  scale_y_continuous(
    expand = expansion(add = 0),
    sec.axis = sec_axis(
      transform = function(x) x - sum_height_spec2,
      name = "LFC over additive expectation",
      breaks = c(0)
    )
  ) +
  scale_x_discrete(
    labels = c(
      add = "Additive\n(Non-interaction)",
      buf = "Buffering",
      syn = "Synergy",
      opp = "Opposite"
    )
  ) +
  scale_fill_manual(values = cartoon_color_values) +
  scale_color_manual(values = cartoon_color_values) +
  coord_cartesian(ylim = c(0, 0.5), clip = "off") +
  guides(
    color = "none",
    fill = "none",
    alpha = "none",
    x = guide_axis(angle = 90)
  ) +
  theme(
    axis.title = element_blank(),
    axis.ticks.y.left = element_blank(),
    axis.text.y.left = element_blank(),
    axis.line.y.left = element_blank()
  )

# Preview of the second cartoon row.
cowplot::plot_grid(train_plot2, obs_plot2, align = "h")

# Both cartoon rows stacked; the upper row drops its x axis text, so only the
# lower row carries the category labels.
cartoon_pl <- cowplot::plot_grid(
  cowplot::plot_grid(
    train_plot + theme(axis.text.x = element_blank()),
    obs_plot + theme(axis.text.x = element_blank()),
    rel_widths = c(3, 4),
    align = "h"
  ),
  cowplot::plot_grid(
    train_plot2,
    obs_plot2,
    rel_widths = c(3, 4),
    align = "h"
  ),
  ncol = 1,
  rel_heights = c(1, 1.5),
  align = "v"
)

cartoon_pl
# Main figure (180 x 193 mm, panels A-F), written to
# ../plots/perturbation_prediction.pdf. All x/y/width/height values are in mm.
plot_assemble(
  # (A) + (B): prediction error and example scatter.
  add_text("(A) Double perturbation prediction error", x = 2.7, y = 1, fontsize = font_size, vjust = 1, fontface = "bold"),
  add_plot(main_pl_double_l2, x = 0, y = 4, width = 125, height = 60),
  add_text("(B) Example:\\;{\\scriptsize\\textcolor{baseROrange}\\faCircle}\\;CEBPE+CEBPB", x = 128, y = 1, fontsize = font_size, vjust = 1, fontface = "bold"),
  add_plot(obs_pred_corr_pl, x = 125, y = 4, width = 58, height = 60),
  # (C): TPR vs. FDP.
  add_text("(C) Accuracy of interaction predictions", x = 2.7, y = 66, fontsize = font_size, vjust = 1, fontface = "bold"),
  add_plot(tp_fdp_pl, x = 0, y = 70, width = 60, height = 58),
  # (D): cartoon with its two rotated, shared axis titles.
  add_text("(D) Classification of interactions", x = 62.7, y = 66, fontsize = font_size, vjust = 1, fontface = "bold"),
  add_plot(cartoon_pl, x = 65, y = 70, width = 65, height = 60),
  add_text("Log Expression of two example genes", x = 64.5, y = 93, angle = 90, fontsize = font_size_small, vjust = 0.5, hjust = 0.5),
  add_text("LFC over additive expectation", x = 130, y = 93, angle = -90, fontsize = font_size_small, vjust = 0.5, hjust = 0.5),
  # (E): the two composition bars, with a translucent polygon drawn between them.
  add_text("(E) Observed composition\nof interaction classes", x = 141, y = 66, fontsize = font_size, vjust = 1, fontface = "bold"),
  add_plot(non_add_pl1, x = 139, y = 76, width = 20, height = 45),
  add_plot(
    grid::polygonGrob(
      x = c(0.398, 0.65, 0.65, 0.398),
      y = c(0.969, 0.969, 0.05, 0.927),
      gp = grid::gpar(fill = non_additive_colors["Non-additive"], alpha = 0.2, lty = 0)
    ),
    # `height` is spelled out in full rather than relying on partial argument matching.
    x = 140, y = 76, width = 37, height = 45
  ),
  add_plot(non_add_pl2, x = 156, y = 76, width = 20, height = 45),
  # (F): scatter panels, mosaic bars below them, and the colour legend placed separately.
  add_text("(F) Prediction of LFC over additive expectation and interaction class", x = 2.7, y = 131.5, fontsize = font_size, vjust = 1, fontface = "bold"),
  add_plot(pert_pred_comparison + guides(color = "none"), x = 4, y = 135, width = 176, height = 40),
  add_plot(mosaic_plot + theme(strip.text = element_blank(), axis.text.x = element_text(size = font_size_tiny)), x = 0, y = 173, width = 180, height = 20),
  add_plot(my_get_legend(pert_pred_comparison), x = 125, y = 131.5, width = 50, height = 5),
  width = 180, height = 193, units = "mm", show_grid_lines = FALSE,
  latex_support = TRUE, filename = "../plots/perturbation_prediction.pdf"
)
## gg[gg1]
## Warning: Removed 147 rows containing missing values or values outside the scale
## range (`position_quasirandom()`).
## gg[gg2]
## gg[gg3]
## Measuring dimensions of: $L_2$: 8.5
## Running command: '/Library/TeX/texbin/pdflatex' -interaction=batchmode -halt-on-error -output-directory '/var/folders/f4/z30qmj3j55db5zh17mq85tr80000gq/T//RtmpQNSStN/tikzDevice172a556902e91' 'tikzStringWidthCalc.tex'
## Measuring dimensions of: $L_2$: 4.7
## Running command: '/Library/TeX/texbin/pdflatex' -interaction=batchmode -halt-on-error -output-directory '/var/folders/f4/z30qmj3j55db5zh17mq85tr80000gq/T//RtmpQNSStN/tikzDevice172a513e15dfd' 'tikzStringWidthCalc.tex'
## Measuring dimensions of: $L_2$: 7.1
## Running command: '/Library/TeX/texbin/pdflatex' -interaction=batchmode -halt-on-error -output-directory '/var/folders/f4/z30qmj3j55db5zh17mq85tr80000gq/T//RtmpQNSStN/tikzDevice172a530e1973d' 'tikzStringWidthCalc.tex'
## Measuring dimensions of: $L_2$: 4.4
## Running command: '/Library/TeX/texbin/pdflatex' -interaction=batchmode -halt-on-error -output-directory '/var/folders/f4/z30qmj3j55db5zh17mq85tr80000gq/T//RtmpQNSStN/tikzDevice172a52a903edd' 'tikzStringWidthCalc.tex'
## Measuring dimensions of: $L_2$: 6.4
## Running command: '/Library/TeX/texbin/pdflatex' -interaction=batchmode -halt-on-error -output-directory '/var/folders/f4/z30qmj3j55db5zh17mq85tr80000gq/T//RtmpQNSStN/tikzDevice172a5640f34ff' 'tikzStringWidthCalc.tex'
## Measuring dimensions of: $L_2$: 2.9
## Running command: '/Library/TeX/texbin/pdflatex' -interaction=batchmode -halt-on-error -output-directory '/var/folders/f4/z30qmj3j55db5zh17mq85tr80000gq/T//RtmpQNSStN/tikzDevice172a5226084ab' 'tikzStringWidthCalc.tex'
## Measuring dimensions of: $L_2$: 6.5
## Running command: '/Library/TeX/texbin/pdflatex' -interaction=batchmode -halt-on-error -output-directory '/var/folders/f4/z30qmj3j55db5zh17mq85tr80000gq/T//RtmpQNSStN/tikzDevice172a56ea6082e' 'tikzStringWidthCalc.tex'
## Measuring dimensions of: $L_2$: 14.8
## Running command: '/Library/TeX/texbin/pdflatex' -interaction=batchmode -halt-on-error -output-directory '/var/folders/f4/z30qmj3j55db5zh17mq85tr80000gq/T//RtmpQNSStN/tikzDevice172a556633cc2' 'tikzStringWidthCalc.tex'
## gg[gg4]
## Measuring dimensions of: \bfseries{}(C) Accuracy of interaction predictions
## Running command: '/Library/TeX/texbin/pdflatex' -interaction=batchmode -halt-on-error -output-directory '/var/folders/f4/z30qmj3j55db5zh17mq85tr80000gq/T//RtmpQNSStN/tikzDevice172a5d2a10dd' 'tikzStringWidthCalc.tex'
## gg[gg5]
## gg[gg6]
## gg[gg7]
## Measuring dimensions of: (Non-interaction)
## Running command: '/Library/TeX/texbin/pdflatex' -interaction=batchmode -halt-on-error -output-directory '/var/folders/f4/z30qmj3j55db5zh17mq85tr80000gq/T//RtmpQNSStN/tikzDevice172a544b923eb' 'tikzStringWidthCalc.tex'
## gg[gg8]
## gg[gg9]
## gg[gg10]
## gg[gg11]
## gg[gg12]
## gg[gg13]
## gg[gg14]
## Measuring dimensions of: \bfseries{}(F) Prediction of LFC over additive expectation and interaction class
## Running command: '/Library/TeX/texbin/pdflatex' -interaction=batchmode -halt-on-error -output-directory '/var/folders/f4/z30qmj3j55db5zh17mq85tr80000gq/T//RtmpQNSStN/tikzDevice172a556e53c8c' 'tikzStringWidthCalc.tex'
## gg[gg15]
## gg[gg16]
## Warning: Removed 40 rows containing missing values or values outside the scale
## range (`geom_col()`).
## gg[gg17]
## gg[gg18]
## [1] TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE
## [16] TRUE TRUE TRUE
# For every (perturbation, gene, seed), collect the display names of the
# methods that have this row among their first 500 rows.
# NOTE(review): slice_head() takes the first 500 rows per method in the current
# row order of tp_fdp_prec_recall_data_pre (last sorted by fdp), whereas
# prediction_label_vs_true_df selects via slice_max(abs(pred_minus_add)).
# Confirm that relying on the row order is intended here.
overlap_preparation_df <- tp_fdp_prec_recall_data_pre |>
  slice_head(n = 500, by = c(method)) |>
  summarize(
    .by = c(perturbation, gene_name, seed, true_nonadditive, interaction_label),
    # Look labels up by name: `method` is a factor, and indexing a named
    # vector with a factor would use its integer codes instead.
    methods = list(method_labels[as.character(method)])
  ) |>
  # Flag rows that at least three methods selected and "No Change" did not.
  mutate(shared_deep_learning_exclusives = lengths(methods) >= 3 & map_lgl(methods, \(x) ! "No Change" %in% x))
# UpSet plot of the method combinations (25 intersections shown); bars are
# filled by the shared_deep_learning_exclusives flag.
upset_plot_pred_overlap <- overlap_preparation_df |>
  ggplot(aes(x = methods)) +
  geom_bar(aes(fill = shared_deep_learning_exclusives), show.legend = FALSE) +
  scale_y_continuous(expand = expansion(add = 0)) +
  ggupset::scale_x_upset(
    sets = unname(method_labels),
    n_intersections = 25
  ) +
  theme(axis.title.x = element_blank()) +
  ggupset::theme_combmatrix(
    combmatrix.panel.point.size = 1,
    combmatrix.panel.line.size = 0.8
  )

upset_plot_pred_overlap
## Warning: Removed 265 rows containing non-finite outside the scale range
## (`stat_count()`).
# Ten most frequent genes among the flagged rows (ties broken arbitrarily so
# that exactly ten rows are returned when available).
tab1 <- overlap_preparation_df |>
  filter(shared_deep_learning_exclusives) |>
  count(gene_name) |>
  slice_max(order_by = n, n = 10, with_ties = FALSE) |>
  dplyr::rename(`Gene Name` = gene_name)

# Ten most frequent perturbations among the flagged rows.
tab2 <- overlap_preparation_df |>
  filter(shared_deep_learning_exclusives) |>
  count(perturbation) |>
  slice_max(order_by = n, n = 10, with_ties = FALSE) |>
  dplyr::rename(`Perturbation` = perturbation)
# Put the two top-10 tables side by side and render them as a LaTeX string
# wrapped in a minipage, so it can be placed as text in the assembled figure.
shared_deep_learning_exclusives_tab <- cbind(tab1, tab2) |>
tinytable::tt() |>
tinytable::style_tt(fontsize = 0.8) |>
# Line on the right-hand side of column 2 (separates the two tables).
tinytable::style_tt(j = 2, line = "r") |>
# NOTE(review): build_tt is accessed with `:::`, i.e. it is not part of
# tinytable's exported interface and may change between versions.
tinytable:::build_tt(output = "latex") %>%
{
# Take the rendered LaTeX source and strip `%%` comment lines as well as the
# enclosing table environment, then wrap the remainder in a minipage.
str <- (.@table_string) |>
str_remove_all("%%.*\n") |>
str_remove(r"(\\begin\{table\})") |>
str_remove(r"(\\end\{table\})")
paste0("\\begin{minipage}{\\textwidth}", str, "\\end{minipage}")
}
# Print the generated LaTeX for inspection.
cat(shared_deep_learning_exclusives_tab)
## \begin{minipage}{\textwidth}
## \centering
## \begin{tblr}[ ] { colspec={Q[]Q[]Q[]Q[]},
## column{1,2,3,4}={}{font=\fontsize{0.8em}{1.1em}\selectfont,},
## vline{3}={1,2,3,4,5,6,7,8,9,10,11}{solid, black, 0.1em},
## } \toprule
## Gene Name & n & Perturbation & n \\ \midrule HBZ & 59 & CEBPE+CEBPA & 36 \\
## HBG2 & 45 & CEBPB+CEBPA & 34 \\
## GYPB & 15 & CEBPB+MAPK1 & 16 \\
## SH3BGRL3 & 13 & CEBPE+CEBPB & 8 \\
## TMSB10 & 13 & JUN+CEBPA & 7 \\
## GYPA & 12 & ZC3HAV1+CEBPE & 7 \\
## HBG1 & 12 & AHR+FEV & 6 \\
## VIM & 8 & ETS2+MAPK1 & 6 \\
## CYBA & 5 & CDKN1C+CDKN1B & 5 \\
## RANBP1 & 5 & CEBPE+RUNX1T1 & 5 \\
## \bottomrule
## \end{tblr}
## \end{minipage}
# Response figure (180 x 65 mm): UpSet plot on the left, LaTeX table of the
# flagged genes/perturbations on the right; written to
# ../plots/response-deep_learning_specific_calls.pdf.
plot_assemble(
  add_text(
    "(A) Overlap of interaction predictions",
    x = 2.7, y = 1, fontsize = font_size, vjust = 1, fontface = "bold"
  ),
  add_plot(
    upset_plot_pred_overlap +
      labs(subtitle = "Set of predictions where at least 3 DL tools predict an interaction and \\emph{no change} does not (blue bars)"),
    x = 0, y = 4.5, width = 110, height = 60
  ),
  add_text(
    "(B) Deep Learning interactions not\n\\;\\;found by \\emph{no change}",
    x = 120, y = 1, fontsize = font_size, vjust = 1, fontface = "bold"
  ),
  # The table is passed as a single line of LaTeX (newlines become spaces).
  add_text(
    str_replace_all(shared_deep_learning_exclusives_tab, "\n", " "),
    x = 100, y = 34, fontsize = font_size, vjust = 1, hjust = 0
  ),
  width = 180,
  height = 65,
  units = "mm",
  show_grid_lines = FALSE,
  latex_support = TRUE,
  filename = "../plots/response-deep_learning_specific_calls.pdf"
)
## Measuring dimensions of: \bfseries{}(A) Overlap of interaction predictions
## Running command: '/Library/TeX/texbin/pdflatex' -interaction=batchmode -halt-on-error -output-directory '/var/folders/f4/z30qmj3j55db5zh17mq85tr80000gq/T//RtmpQNSStN/tikzDevice172a532cc807d' 'tikzStringWidthCalc.tex'
## gg[gg1]
## Warning: Removed 265 rows containing non-finite outside the scale range
## (`stat_count()`).
## Measuring dimensions of: Set of predictions where at least 3 DL tools predict an interaction and \emph{no change} does not (blue bars)
## Running command: '/Library/TeX/texbin/pdflatex' -interaction=batchmode -halt-on-error -output-directory '/var/folders/f4/z30qmj3j55db5zh17mq85tr80000gq/T//RtmpQNSStN/tikzDevice172a51007a899' 'tikzStringWidthCalc.tex'
## gg[gg2]
## gg[gg3]
## Measuring dimensions of: \begin{minipage}{\textwidth} \centering \begin{tblr}[ ] { colspec={Q[]Q[]Q[]Q[]}, column{1,2,3,4}={}{font=\fontsize{0.8em}{1.1em}\selectfont,}, vline{3}={1,2,3,4,5,6,7,8,9,10,11}{solid, black, 0.1em}, } \toprule Gene Name & n & Perturbation & n \\ \midrule HBZ & 59 & CEBPE+CEBPA & 36 \\ HBG2 & 45 & CEBPB+CEBPA & 34 \\ GYPB & 15 & CEBPB+MAPK1 & 16 \\ SH3BGRL3 & 13 & CEBPE+CEBPB & 8 \\ TMSB10 & 13 & JUN+CEBPA & 7 \\ GYPA & 12 & ZC3HAV1+CEBPE & 7 \\ HBG1 & 12 & AHR+FEV & 6 \\ VIM & 8 & ETS2+MAPK1 & 6 \\ CYBA & 5 & CDKN1C+CDKN1B & 5 \\ RANBP1 & 5 & CEBPE+RUNX1T1 & 5 \\ \bottomrule \end{tblr} \end{minipage}
## Running command: '/Library/TeX/texbin/pdflatex' -interaction=batchmode -halt-on-error -output-directory '/var/folders/f4/z30qmj3j55db5zh17mq85tr80000gq/T//RtmpQNSStN/tikzDevice172a566cddd07' 'tikzStringWidthCalc.tex'
## gg[gg4]
# Preview the label columns that are joined onto the predictions below;
# `label` is exposed under the name `interaction_label`.
dplyr::select(
  gene_response_label_df,
  gene_name, pert_group, seed, pert_same_dir,
  interaction_label = label
)
## # A tibble: 620,000 × 5
## gene_name pert_group seed pert_same_dir interaction_label
## <chr> <chr> <int> <lgl> <fct>
## 1 ABCF1 AHR+FEV 1 TRUE Additive
## 2 ABRACL AHR+FEV 1 FALSE Additive
## 3 ABT1 AHR+FEV 1 FALSE Additive
## 4 ACP1 AHR+FEV 1 FALSE Additive
## 5 ACTB AHR+FEV 1 FALSE Additive
## 6 ACTG1 AHR+FEV 1 TRUE Additive
## 7 ACTN4 AHR+FEV 1 TRUE Additive
## 8 ADRM1 AHR+FEV 1 FALSE Additive
## 9 AK2 AHR+FEV 1 TRUE Additive
## 10 ALDOA AHR+FEV 1 FALSE Additive
## # ℹ 619,990 more rows
# Per method and seed: take the 100 rows with the largest absolute
# `pred_minus_add` and count how often each gene occurs among them.
# Only the four overall most frequent genes keep their name; all remaining
# genes are collapsed into "Other".
top_gene_reoccurence <- inter_pred_dat %>%
  tidylog::inner_join(
    dplyr::select(gene_response_label_df, gene_name, pert_group, seed, pert_same_dir, interaction_label = label),
    by = c("perturbation" = "pert_group", "gene_name", "seed")
  ) %>%
  tidylog::filter(pert_same_dir) %>%
  (\(labelled) {
    # Add a pseudo-method "ground_truth" whose value is the observed deviation
    # (`obs_minus_add`); the rows of `no_change` are used as the template
    # (presumably because the observed values are shared across methods).
    observed <- labelled %>%
      filter(method == "no_change") %>%
      transmute(method = "ground_truth", seed, perturbation, gene_name,
                pred_minus_add = obs_minus_add)
    bind_rows(labelled, observed)
  }) %>%
  dplyr::select(seed, method, perturbation, gene_name, pred_minus_add) %>%
  group_by(method, seed) %>%
  slice_max(abs(pred_minus_add), n = 100, with_ties = FALSE) %>%
  ungroup() %>%
  mutate(
    # Order levels by frequency first, so that `levels(...)[1:4]` are the four
    # most frequent genes when the second expression is evaluated.
    gene_name = fct_infreq(gene_name),
    gene_name = fct_other(gene_name, keep = levels(gene_name)[1:4])
  ) %>%
  count(seed, method, gene_name)
## inner_join: added 2 columns (pert_same_dir, interaction_label)
## > rows only in x ( 0)
## > rows only in dplyr::select(gene_resp.. ( 310,000)
## > matched rows 2,790,000
## > ===========
## > rows total 2,790,000
## filter: removed 1,056,897 rows (38%), 1,733,103 rows remaining
# Per method and seed: take the 100 rows with the largest absolute
# `pred_minus_add` and count how often each perturbation occurs among them.
# Only the four overall most frequent perturbations keep their name; all
# remaining ones are collapsed into "Other".
top_perturbation_reoccurence <- inter_pred_dat %>%
  tidylog::inner_join(
    dplyr::select(gene_response_label_df, gene_name, pert_group, seed, pert_same_dir, interaction_label = label),
    by = c("perturbation" = "pert_group", "gene_name", "seed")
  ) %>%
  tidylog::filter(pert_same_dir) %>%
  (\(labelled) {
    # Add a pseudo-method "ground_truth" whose value is the observed deviation
    # (`obs_minus_add`); the rows of `no_change` are used as the template
    # (presumably because the observed values are shared across methods).
    observed <- labelled %>%
      filter(method == "no_change") %>%
      transmute(method = "ground_truth", seed, perturbation, gene_name,
                pred_minus_add = obs_minus_add)
    bind_rows(labelled, observed)
  }) %>%
  dplyr::select(seed, method, perturbation, gene_name, pred_minus_add) %>%
  group_by(method, seed) %>%
  slice_max(abs(pred_minus_add), n = 100, with_ties = FALSE) %>%
  ungroup() %>%
  mutate(
    # Order levels by frequency first, so that `levels(...)[1:4]` are the four
    # most frequent perturbations when the second expression is evaluated.
    perturbation = fct_infreq(perturbation),
    perturbation = fct_other(perturbation, keep = levels(perturbation)[1:4])
  ) %>%
  count(seed, method, perturbation)
## inner_join: added 2 columns (pert_same_dir, interaction_label)
## > rows only in x ( 0)
## > rows only in dplyr::select(gene_resp.. ( 310,000)
## > matched rows 2,790,000
## > ===========
## > rows total 2,790,000
## filter: removed 1,056,897 rows (38%), 1,733,103 rows remaining
# NOTE: despite the name, this palette holds four colours -- one per retained
# factor level; "grey" is appended below for the collapsed "Other" level.
ggplot_colors_six <- colorspace::qualitative_hcl(n = 4, h = c(0, 270), c = 60, l = 70)

# Stacked bars (one facet column per seed): composition of the top-100 lists by gene.
top_gene_reoccurence_plot <- top_gene_reoccurence |>
  filter(method != "additive_model") |>
  mutate(method = factor(method, levels = c("ground_truth", names(method_labels)))) |>
  ggplot(aes(x = method, y = n)) +
  geom_col(aes(fill = gene_name)) +
  facet_grid(rows = vars(), cols = vars(seed), labeller = label_both) +
  scale_x_discrete(labels = c("ground_truth" = "Ground Truth", method_labels)) +
  scale_y_continuous(expand = expansion(add = 0)) +
  scale_fill_manual(values = c(ggplot_colors_six, "grey")) +
  labs(x = "", y = "No. occurrences", fill = "") +
  guides(x = guide_axis(angle = 90))

# Same layout, but the bars are split by perturbation instead of gene.
top_perturbation_reoccurence_plot <- top_perturbation_reoccurence |>
  filter(method != "additive_model") |>
  mutate(method = factor(method, levels = c("ground_truth", names(method_labels)))) |>
  ggplot(aes(x = method, y = n)) +
  geom_col(aes(fill = perturbation)) +
  facet_grid(rows = vars(), cols = vars(seed), labeller = label_both) +
  scale_x_discrete(labels = c("ground_truth" = "Ground Truth", method_labels)) +
  scale_y_continuous(expand = expansion(add = 0)) +
  scale_fill_manual(values = c(ggplot_colors_six, "grey")) +
  labs(x = "", y = "No. occurrences", fill = "") +
  guides(x = guide_axis(angle = 90))

top_gene_reoccurence_plot
top_perturbation_reoccurence_plot
# Distribution (over seed x prediction method) of how many of the top-100
# entries fall on the named, i.e. non-"Other", genes.
top_gene_reoccurence |>
  filter(!(method == "ground_truth" | method == "additive_model")) |>
  summarize(n_named = sum(n[gene_name != "Other"]), .by = c(seed, method)) |>
  pull(n_named) |>
  summary()
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## 16.00 39.00 53.00 47.88 58.25 64.00
# Distribution (over seeds) of how many of the top-100 ground-truth entries
# fall on the named, i.e. non-"Other", perturbations.
# The original predicate was `method == "ground_truth" & method != "additive_model"`;
# its second clause is always TRUE once the first holds, so it is dropped here
# without changing the result.
# NOTE(review): the analogous gene summary excludes the ground truth and looks
# at the prediction methods instead -- confirm that restricting this summary to
# the ground truth is intended and not a `==` / `!=` slip.
top_perturbation_reoccurence |>
  filter(method == "ground_truth") |>
  summarize(n_named = sum(n[perturbation != "Other"]), .by = c(seed, method)) |>
  pull(n_named) |>
  summary()
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## 7 55 55 48 59 64
# Assemble the supplementary figure "suppl-non_additive_gene_reoccurence.pdf":
# panel (A) shows the gene composition of the top-100 lists, panel (B) the
# perturbation composition, stacked vertically on a 180 x 140 canvas
# (`units = "mm"`).
plot_assemble(
add_text("(A) Reoccuring genes among top 100 interaction predictions", x = 2.7, y = 1, fontsize = font_size, vjust = 1, fontface = "bold"),
add_plot(top_gene_reoccurence_plot, x = 0, y = 4, width = 180, height = 65),
add_text("(B) Reoccuring perturbations among top 100 interaction predictions", x = 2.7, y = 70, fontsize = font_size, vjust = 1, fontface = "bold"),
add_plot(top_perturbation_reoccurence_plot, x = 0, y = 73, width = 180, height = 65),
width = 180, height = 140, units = "mm", show_grid_lines = FALSE,
latex_support = TRUE, filename = "../plots/suppl-non_additive_gene_reoccurence.pdf"
)
## Measuring dimensions of: \bfseries{}(A) Reoccuring genes among top 100 interaction predictions
## Running command: '/Library/TeX/texbin/pdflatex' -interaction=batchmode -halt-on-error -output-directory '/var/folders/f4/z30qmj3j55db5zh17mq85tr80000gq/T//RtmpQNSStN/tikzDevice172a520efc31d' 'tikzStringWidthCalc.tex'
## gg[gg1]
## gg[gg2]
## Measuring dimensions of: \bfseries{}(B) Reoccuring perturbations among top 100 interaction predictions
## Running command: '/Library/TeX/texbin/pdflatex' -interaction=batchmode -halt-on-error -output-directory '/var/folders/f4/z30qmj3j55db5zh17mq85tr80000gq/T//RtmpQNSStN/tikzDevice172a55cf2b5cf' 'tikzStringWidthCalc.tex'
## gg[gg3]
## gg[gg4]
# Data for the HBG2 / HBZ detail plots: labelled predictions of seed 1,
# restricted to rows with `pert_same_dir`, without the additive model.
plot_data <- inter_pred_dat %>%
  tidylog::inner_join(
    dplyr::select(gene_response_label_df, gene_name, pert_group, seed, pert_same_dir, interaction_label = label),
    by = c("perturbation" = "pert_group", "gene_name", "seed")
  ) %>%
  tidylog::filter(pert_same_dir) %>%
  filter(seed == 1, method != "additive_model") %>%
  # The rank is computed within each method over all genes, i.e. before the
  # data is narrowed down to the two genes of interest.
  mutate(rank = rank(-abs(pred_minus_add)), .by = method) %>%
  filter(gene_name %in% c("HBG2", "HBZ")) %>%
  mutate(gene_name = factor(gene_name, levels = c("HBG2", "HBZ")))
## inner_join: added 2 columns (pert_same_dir, interaction_label)
## > rows only in x ( 0)
## > rows only in dplyr::select(gene_resp.. ( 310,000)
## > matched rows 2,790,000
## > ===========
## > rows total 2,790,000
## filter: removed 1,056,897 rows (38%), 1,733,103 rows remaining
#' Observed vs. predicted expression of one gene across double perturbations
#'
#' Draws one facet per method. Per perturbation it shows the reference value
#' (`ref`, short horizontal line), a grey box for the additive range, the
#' observed value (`ground_truth`, coloured by `interaction_label`) and the
#' predicted value (`value`, small square).
#'
#' @param data Data frame with the columns used below (`gene_name`,
#'   `perturbation`, `method`, `ref`, `baseline`, `ground_truth`, `value`,
#'   `obs_minus_add`, `lower_thres`, `upper_thres`, `interaction_label`).
#' @param gene Name of the gene to plot (single string).
#' @return A ggplot object.
make_expression_panel <- function(data, gene) {
  data |>
    # `.env$gene` makes sure the function argument is used even if `data`
    # happens to contain a column called `gene`.
    filter(gene_name == .env$gene) |>
    mutate(perturbation = fct_reorder(perturbation, ref)) |>
    # Kept from the original pipeline; this column is currently not mapped to
    # any aesthetic of the plot.
    mutate(true_nonadditive = obs_minus_add > upper_thres | obs_minus_add < lower_thres) |>
    ggplot(aes(x = perturbation)) +
    geom_hline(aes(yintercept = baseline), alpha = 0.1) +
    ungeviz::geom_hpline(aes(y = ref), color = "darkgrey", linewidth = 0.3, width = 0.8) +
    # Additive range: the band spans from the lower to the upper threshold
    # around `ref`, matching the definition of `true_nonadditive` above.
    # The previous version used `ref - lower_thres` / `ref + lower_thres`,
    # which ignores `upper_thres` and is only correct for symmetric thresholds.
    # NOTE(review): assumes `ref` is the additive expectation -- confirm.
    geom_rect(
      aes(
        xmin = stage(perturbation, after_scale = x - 0.5),
        xmax = stage(perturbation, after_scale = x + 0.5),
        ymin = ref + lower_thres,
        ymax = ref + upper_thres
      ),
      fill = "#e2e2e2", alpha = 0.3
    ) +
    geom_point(aes(y = ground_truth, color = interaction_label), size = 1.6, stroke = 0) +
    geom_point(aes(y = value), size = 0.6, stroke = 0, shape = "square") +
    scale_color_manual(values = non_additive_colors, drop = TRUE) +
    scale_alpha_manual(values = c("TRUE" = 1, "FALSE" = 0.2)) +
    facet_wrap(vars(method), nrow = 3, labeller = as_labeller(method_labels)) +
    guides(x = guide_axis(angle = 90), color = "none", alpha = "none") +
    labs(y = paste("Expression of", gene), x = "Double perturbation") +
    theme(axis.text.x = element_text(size = 4))
}

pl1 <- make_expression_panel(plot_data, "HBG2")
pl2 <- make_expression_panel(plot_data, "HBZ")
# Assemble the supplementary figure "suppl-top_gene_plots.pdf": a bold title,
# a LaTeX-formatted explanatory line (coloured circles / square as inline
# legend symbols), the HBG2 panel, the HBZ panel and a separate horizontal
# colour legend extracted from `pl2`.
plot_assemble(
add_text("Analysis of the predicted and observed expression patterns for HBG2 and HBZ", x = 2.7, y = 1, fontsize = font_size, vjust = 1, fontface = "bold"),
add_text(paste0("Comparison of the observed expression {\\scriptsize\\textcolor{nonAdditivePurple}\\faCircle}\\,{\\scriptsize\\textcolor{nonAdditiveGrey}\\faCircle}\\,{\\scriptsize\\textcolor{nonAdditiveOrange}\\faCircle} ",
"against predicted value {\\tiny\\faSquare} for each double perturbation. The grey box in the background shows the additive range."),
x = 2.7, y = 6, fontsize = font_size_small, vjust = 1),
add_plot(pl1, x = 0, y = 8, width = 180, height = 90),
add_plot(pl2, x = 0, y = 100, width = 180, height = 90),
# The panels themselves hide the colour guide, so it is re-enabled here just to extract the legend.
add_plot(my_get_legend(pl2 + guides(color = guide_legend(title = "", direction = "horizontal", nrow = 1))),
x = 140, y = 170, width = 20, height = 10),
width = 180, height = 190, units = "mm", show_grid_lines = FALSE,
latex_support = TRUE, filename = "../plots/suppl-top_gene_plots.pdf"
)
## gg[gg1]
## gg[gg2]
## gg[gg3]
## gg[gg4]
## gg[gg5]
# Exploratory check for seed 2: which genes occur most often among the 100
# largest `pred_minus_add` values of each method?
# Note that this ranks by the signed value (not the absolute value), and that
# the final `slice_max` keeps ties, so a method in which many genes share the
# same count contributes more than three rows.
inter_pred_dat |>
  filter(seed == 2) |>
  group_by(method) |>
  slice_max(order_by = pred_minus_add, n = 100, with_ties = FALSE) |>
  count(gene_name) |>
  slice_max(order_by = n, n = 3)
## # A tibble: 126 × 3
## # Groups: method [9]
## method gene_name n
## <fct> <chr> <int>
## 1 no_change HBG2 9
## 2 no_change GAL 6
## 3 no_change HBZ 5
## 4 additive_model ABCF1 1
## 5 additive_model ABRACL 1
## 6 additive_model ABT1 1
## 7 additive_model ACP1 1
## 8 additive_model ACTB 1
## 9 additive_model ACTG1 1
## 10 additive_model ACTN4 1
## # ℹ 116 more rows
# Exploratory check per seed: which perturbations occur most often among the
# 100 largest observed deviations (`obs_minus_add`)? Only the `no_change` rows
# are used. The final `slice_max` keeps ties, so a seed can return more than
# three rows.
inter_pred_dat |>
  filter(method == "no_change") |>
  group_by(seed, method) |>
  slice_max(order_by = obs_minus_add, n = 100, with_ties = FALSE) |>
  count(perturbation) |>
  slice_max(order_by = n, n = 3)
## # A tibble: 17 × 4
## # Groups: seed, method [5]
## seed method perturbation n
## <int> <fct> <chr> <int>
## 1 1 no_change CEBPB+CEBPA 52
## 2 1 no_change CEBPE+KLF1 9
## 3 1 no_change FEV+CBFA2T3 7
## 4 2 no_change CEBPB+CEBPA 56
## 5 2 no_change CEBPE+KLF1 12
## 6 2 no_change CEBPE+CEBPB 6
## 7 2 no_change PTPN12+UBASH3A 6
## 8 3 no_change CEBPE+CEBPA 52
## 9 3 no_change CEBPE+KLF1 9
## 10 3 no_change FEV+CBFA2T3 7
## 11 3 no_change ZC3HAV1+CEBPE 7
## 12 4 no_change CEBPE+KLF1 21
## 13 4 no_change ZC3HAV1+CEBPE 19
## 14 4 no_change CEBPE+PTPN12 12
## 15 5 no_change CEBPB+CEBPA 35
## 16 5 no_change CEBPE+CEBPA 28
## 17 5 no_change SET+CEBPE 14
# Read the per-job resource statistics (tab-separated). The columns are
# `name`, `metric`, `gpu_logged`, `gpu_ask`, `node`, `gpu_available` (character)
# and `value` (numeric); column types are guessed by readr.
resource_df <- read_tsv("../benchmark/output/single_perturbation_jobs_stats.tsv")
## Rows: 608 Columns: 7
## ── Column specification ──────────────────────────────────────────────────────────────────────────────────────────────────
## Delimiter: "\t"
## chr (6): name, metric, gpu_logged, gpu_ask, node, gpu_available
## dbl (1): value
##
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
# Resource statistics of the known methods on the "norman_from_scfoundation"
# dataset. The job name is split at "-" into dataset, seed and method, and the
# GPU annotation is translated into a display label.
norman_df <- resource_df |>
  separate(name, into = c("dataset", "seed", "method"), sep = "-") |>
  filter(
    method %in% names(method_labels),
    dataset == "norman_from_scfoundation"
  ) |>
  mutate(
    gpu_available = str_remove(gpu_available, "gpu="),
    gpu_label = case_when(
      # Jobs without a GPU annotation
      is.na(gpu_available) ~ "No GPU",
      gpu_available == "A40" ~ "NVIDIA A40",
      gpu_available == "3090" ~ "NVIDIA RTX 3090",
      gpu_available == "H100" ~ "NVIDIA H100",
      gpu_available == "L40s" ~ "NVIDIA L40s",
      .default = "Other GPU"
    )
  )

# Cross-check the derived label against the logged and the annotated GPU.
count(norman_df, gpu_logged, gpu_available, gpu_label)
## # A tibble: 7 × 4
## gpu_logged gpu_available gpu_label n
## <chr> <chr> <chr> <int>
## 1 GPU: NVIDIA A40 A40 NVIDIA A40 4
## 2 GPU: NVIDIA GeForce RTX 3090 3090 NVIDIA RTX 3090 20
## 3 GPU: NVIDIA H100 PCIe H100 NVIDIA H100 20
## 4 <NA> 3090 NVIDIA RTX 3090 56
## 5 <NA> A40 NVIDIA A40 36
## 6 <NA> L40s NVIDIA L40s 4
## 7 <NA> <NA> No GPU 20
# Peak memory per method, one point per job, coloured by GPU label.
# `max_mem_kbytes` is multiplied by 1000 to obtain bytes for `label_bytes()`.
# NOTE(review): this treats one "kbyte" as 1000 bytes -- confirm that the
# logged unit is not 1024 bytes.
mem_pl <- norman_df |>
  filter(metric == "max_mem_kbytes", method != "ground_truth") |>
  mutate(method = factor(method, levels = names(method_labels))) |>
  ggplot(aes(x = method, y = value * 1000)) +
  ggbeeswarm::geom_quasirandom(aes(color = gpu_label), width = 0.2, size = 0.4) +
  scale_x_discrete(labels = method_labels) +
  scale_y_continuous(labels = scales::label_bytes()) +
  guides(
    x = guide_axis(angle = 90),
    # Enlarge the legend keys; the plotted points themselves are tiny.
    color = guide_legend(override.aes = list(size = 1))
  ) +
  labs(y = "Peak memory usage (RAM)", x = "", color = "GPU") +
  theme(panel.grid.major.y = element_line(color = "lightgrey", linewidth = 0.2))

# Elapsed time per method on a log scale; breaks are given in seconds.
# The lower axis limit is fixed at 60 (one minute).
dur_pl <- norman_df |>
  filter(metric == "elapsed", method != "ground_truth") |>
  mutate(method = factor(method, levels = names(method_labels))) |>
  ggplot(aes(x = method, y = value)) +
  ggbeeswarm::geom_quasirandom(aes(color = gpu_label), width = 0.2, size = 0.4) +
  scale_x_discrete(labels = method_labels) +
  scale_y_log10(
    limits = c(60, NA),
    breaks = c(60, 600, 3600, 21600, 86400, 259200),
    labels = c("1 min", "10 min", "1 hour", "6 hours", "1 day", "3 days")
  ) +
  guides(x = guide_axis(angle = 90)) +
  labs(y = "Duration", x = "", color = "GPU") +
  theme(panel.grid.major.y = element_line(color = "lightgrey", linewidth = 0.2))

mem_pl
dur_pl
# Assemble the supplementary figure "suppl-resource_usage.pdf":
# (A) duration and (B) peak memory side by side. The colour legend is
# suppressed in (A) and shown only once, in the wider panel (B).
plot_assemble(
add_text("(A)", x = 2.7, y = 1, fontsize = font_size, vjust = 1, fontface = "bold"),
add_plot(dur_pl + guides(color = "none"), x = 0, y = 4, width = 76, height = 47.5),
add_text("(B)", x = 82, y = 1, fontsize = font_size, vjust = 1, fontface = "bold"),
add_plot(mem_pl, x = 82, y = 4, width = 98, height = 47.5),
width = 180, height = 52, units = "mm", show_grid_lines = FALSE,
latex_support = TRUE, filename = "../plots/suppl-resource_usage.pdf"
)
## gg[gg1]
## gg[gg2]
## gg[gg3]
## gg[gg4]
# Print the R version, platform and package versions of this session.
sessionInfo()
## R version 4.4.1 (2024-06-14)
## Platform: aarch64-apple-darwin20
## Running under: macOS Sonoma 14.6
##
## Matrix products: default
## BLAS: /Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/lib/libRblas.0.dylib
## LAPACK: /Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/lib/libRlapack.dylib; LAPACK version 3.12.0
##
## locale:
## [1] en_GB.UTF-8/en_GB.UTF-8/en_GB.UTF-8/C/en_GB.UTF-8/en_GB.UTF-8
##
## time zone: Europe/London
## tzcode source: internal
##
## attached base packages:
## [1] stats graphics grDevices datasets utils methods base
##
## other attached packages:
## [1] glue_1.7.0 lubridate_1.9.3 forcats_1.0.0 stringr_1.5.1
## [5] dplyr_1.1.4 purrr_1.0.2 readr_2.1.5 tidyr_1.3.1
## [9] tibble_3.2.1 ggplot2_3.5.1 tidyverse_2.0.0
##
## loaded via a namespace (and not attached):
## [1] RColorBrewer_1.1-3 strawr_0.0.92
## [3] jsonlite_1.8.8 shape_1.4.6.1
## [5] magrittr_2.0.3 ggbeeswarm_0.7.2
## [7] farver_2.1.2 rmarkdown_2.27
## [9] GlobalOptions_0.1.2 fs_1.6.4
## [11] BiocIO_1.14.0 zlibbioc_1.50.0
## [13] vctrs_0.6.5 locfdr_1.1-8
## [15] memoise_2.0.1 Cairo_1.6-2
## [17] Rsamtools_2.20.0 RCurl_1.98-1.16
## [19] htmltools_0.5.8.1 S4Arrays_1.4.1
## [21] curl_5.2.1 SparseArray_1.4.8
## [23] gridGraphics_0.5-1 sass_0.4.9
## [25] bslib_0.8.0 legendry_0.2.0
## [27] zoo_1.8-12 cachem_1.1.0
## [29] GenomicAlignments_1.40.0 lifecycle_1.0.4
## [31] iterators_1.0.14 pkgconfig_2.0.3
## [33] Matrix_1.7-0 R6_2.5.1
## [35] fastmap_1.2.0 santoku_1.0.0
## [37] GenomeInfoDbData_1.2.12 tikzDevice_0.12.6
## [39] MatrixGenerics_1.16.0 clue_0.3-65
## [41] digest_0.6.36 ggbezier_0.1.0
## [43] colorspace_2.1-1 S4Vectors_0.42.1
## [45] GenomicRanges_1.56.1 labeling_0.4.3
## [47] tinytable_0.7.0 fansi_1.0.6
## [49] timechange_0.3.0 httr_1.4.7
## [51] polyclip_1.10-7 abind_1.4-5
## [53] compiler_4.4.1 bit64_4.0.5
## [55] withr_3.0.1 doParallel_1.0.17
## [57] BiocParallel_1.38.0 ggupset_0.4.0
## [59] highr_0.11 ggforce_0.4.2
## [61] MASS_7.3-60.2 DelayedArray_0.30.1
## [63] lemur_1.2.0 rjson_0.2.21
## [65] wesanderson_0.3.7 tools_4.4.1
## [67] vipor_0.4.7 filehash_2.4-6
## [69] beeswarm_0.4.0 glmGamPoi_1.16.0
## [71] restfulr_0.0.15 shadowtext_0.1.4
## [73] grid_4.4.1 cluster_2.1.6
## [75] generics_0.1.3 gtable_0.3.5
## [77] strapgod_0.0.4.9000 tzdb_0.4.0
## [79] ungeviz_0.1.0 data.table_1.15.4
## [81] hms_1.1.3 utf8_1.2.4
## [83] XVector_0.44.0 BiocGenerics_0.50.0
## [85] ggrepel_0.9.6 foreach_1.5.2
## [87] pillar_1.9.0 vroom_1.6.5
## [89] yulab.utils_0.1.5 splines_4.4.1
## [91] circlize_0.4.16 tweenr_2.0.3
## [93] lattice_0.22-6 bit_4.0.5
## [95] renv_1.0.7 rtracklayer_1.64.0
## [97] tidyselect_1.2.1 SingleCellExperiment_1.26.0
## [99] ComplexHeatmap_2.20.0 Biostrings_2.72.1
## [101] knitr_1.48 IRanges_2.38.1
## [103] SummarizedExperiment_1.34.0 stats4_4.4.1
## [105] xfun_0.50.5 Biobase_2.64.0
## [107] matrixStats_1.3.0 stringi_1.8.4
## [109] UCSC.utils_1.0.0 yaml_2.3.10
## [111] evaluate_0.24.0 codetools_0.2-20
## [113] BiocManager_1.30.23 ggplotify_0.1.2
## [115] cli_3.6.3 munsell_0.5.1
## [117] jquerylib_0.1.4 Rcpp_1.0.13
## [119] GenomeInfoDb_1.40.1 tidylog_1.1.0
## [121] png_0.1-8 XML_3.99-0.17
## [123] ggrastr_1.0.2 parallel_4.4.1
## [125] assertthat_0.2.1 ggh4x_0.2.8
## [127] plyranges_1.24.0 bitops_1.0-8
## [129] scales_1.3.0 plotgardener_1.10.2
## [131] crayon_1.5.3 clisymbols_1.2.0
## [133] GetoptLong_1.0.5 rlang_1.1.4
## [135] cowplot_1.1.3